From c73386fa03be05d3861dae3a7618abf11141ed2a Mon Sep 17 00:00:00 2001 From: Daniel Goldstein Date: Fri, 20 Oct 2023 17:47:38 -0400 Subject: [PATCH] [query] Avoid py4j for python-backend interactions (#13797) CHANGELOG: Fixes #13756: operations that collect large results such as `to_pandas` may require up to 3x less memory. This turns all "actions", i.e. backend methods supported by QoB into HTTP endpoints on the spark and local backends. This intentionally avoids py4j because py4j was really designed to pass function names and references around and does not handle large payloads well (such as results from a `collect`). Specifically, py4j uses a text-based protocol on top of TCP that substantially inflates the memory requirement for communicating large byte arrays. On the Java side, py4j serializes every binary payload as a Base64-encoded `java.lang.String`, which between the Base64 encoding and `String`'s use of UTF-16 results in a memory footprint of the `String` being `4/3 * 2 = 8/3` nearly three times the size of the byte array on either side of the py4j pipe. py4j also appears to do an entire copy of this payload, which means nearly a 6x memory requirement for sending back bytes. Using our own socket means we can directly send back the response bytes to python without any of this overhead, even going so far as to encode results directly into the TCP output stream. Formalizing the API between python and java also allows us to reuse the same payload schema across all three backends. --- hail/python/hail/backend/backend.py | 118 +++- hail/python/hail/backend/local_backend.py | 156 +---- hail/python/hail/backend/py4j_backend.py | 271 +++++--- hail/python/hail/backend/service_backend.py | 412 ++++-------- hail/python/hail/backend/spark_backend.py | 194 +----- hail/python/hail/context.py | 6 +- .../hail/expr/expressions/expression_utils.py | 15 +- hail/python/hail/ir/__init__.py | 3 +- hail/python/hail/ir/base_ir.py | 2 +- hail/python/hail/ir/blockmatrix_ir.py | 15 - hail/python/hail/ir/ir.py | 31 +- hail/python/hail/ir/matrix_ir.py | 20 - hail/python/hail/ir/renderer.py | 62 +- hail/python/hail/ir/table_ir.py | 16 +- hail/python/hail/matrixtable.py | 4 - hail/python/hail/table.py | 4 +- .../test/hail/backend/test_service_backend.py | 15 +- .../hail/genetics/test_reference_genome.py | 2 + hail/src/main/scala/is/hail/HailContext.scala | 3 - .../main/scala/is/hail/HailFeatureFlags.scala | 2 +- .../main/scala/is/hail/backend/Backend.scala | 117 ++++ .../scala/is/hail/backend/BackendServer.scala | 95 +++ .../is/hail/backend/local/LocalBackend.scala | 109 ++-- .../scala/is/hail/backend/service/Main.scala | 2 +- .../hail/backend/service/ServiceBackend.scala | 599 ++++++------------ .../is/hail/backend/service/Worker.scala | 4 +- .../is/hail/backend/spark/SparkBackend.scala | 78 ++- .../main/scala/is/hail/expr/ir/Parser.scala | 16 +- .../src/main/scala/is/hail/io/CodecSpec.scala | 8 +- .../main/scala/is/hail/types/TableType.scala | 10 + .../test/scala/is/hail/expr/ir/IRSuite.scala | 17 +- 31 files changed, 1054 insertions(+), 1352 deletions(-) create mode 100644 hail/src/main/scala/is/hail/backend/BackendServer.scala diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 0292b35f071..9527bbe765d 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -1,5 +1,7 @@ from typing import Mapping, List, Union, TypeVar, Tuple, Dict, Optional, Any, AbstractSet import abc +from enum import Enum +from dataclasses import dataclass import warnings import orjson import pkg_resources @@ -10,8 +12,12 @@ from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS from ..expr import Expression -from ..expr.types import HailType -from ..ir import BaseIR +from ..expr.table_type import ttable +from ..expr.matrix_type import tmatrix +from ..expr.blockmatrix_type import tblockmatrix +from ..expr.types import HailType, dtype +from ..ir import BaseIR, finalize_randomness +from ..ir.renderer import CSERenderer from ..linalg.blockmatrix import BlockMatrix from ..matrixtable import MatrixTable from ..table import Table @@ -70,6 +76,57 @@ def local_jar_information() -> LocalJarInformation: ) +class ActionTag(Enum): + LOAD_REFERENCES_FROM_DATASET = 1 + VALUE_TYPE = 2 + TABLE_TYPE = 3 + MATRIX_TABLE_TYPE = 4 + BLOCK_MATRIX_TYPE = 5 + EXECUTE = 6 + PARSE_VCF_METADATA = 7 + IMPORT_FAM = 8 + FROM_FASTA_FILE = 9 + +@dataclass +class ActionPayload: + pass + +@dataclass +class IRTypePayload(ActionPayload): + ir: str + +@dataclass +class ExecutePayload(ActionPayload): + ir: str + stream_codec: str + timed: bool + +@dataclass +class LoadReferencesFromDatasetPayload(ActionPayload): + path: str + +@dataclass +class ParseVCFMetadataPayload(ActionPayload): + path: str + +@dataclass +class ImportFamPayload(ActionPayload): + path: str + quant_pheno: bool + delimiter: str + missing: str + +@dataclass +class FromFASTAFilePayload(ActionPayload): + name: str + fasta_file: str + index_file: str + x_contigs: List[str] + y_contigs: List[str] + mt_contigs: List[str] + par: List[str] + + class Backend(abc.ABC): # Must match knownFlags in HailFeatureFlags.scala _flags_env_vars_and_defaults: Dict[str, Tuple[str, Optional[str]]] = { @@ -115,33 +172,52 @@ def validate_file(self, uri: str): def stop(self): pass - @abc.abstractmethod def execute(self, ir: BaseIR, timed: bool = False) -> Any: - pass + payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed) + try: + result, timings = self._rpc(ActionTag.EXECUTE, payload) + except FatalError as e: + raise e.maybe_user_error(ir) from None + value = ir.typ._from_encoding(result) + return (value, timings) if timed else value @abc.abstractmethod - async def _async_execute(self, ir, timed=False): + def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]: pass - @abc.abstractmethod + def _render_ir(self, ir): + r = CSERenderer() + return r(finalize_randomness(ir)) + def value_type(self, ir): - pass + payload = IRTypePayload(ir=self._render_ir(ir)) + type_bytes, _ = self._rpc(ActionTag.VALUE_TYPE, payload) + return dtype(type_bytes.decode('utf-8')) - @abc.abstractmethod def table_type(self, tir): - pass + payload = IRTypePayload(ir=self._render_ir(tir)) + type_bytes, _ = self._rpc(ActionTag.TABLE_TYPE, payload) + return ttable._from_json(orjson.loads(type_bytes)) - @abc.abstractmethod def matrix_type(self, mir): - pass + payload = IRTypePayload(ir=self._render_ir(mir)) + type_bytes, _ = self._rpc(ActionTag.MATRIX_TABLE_TYPE, payload) + return tmatrix._from_json(orjson.loads(type_bytes)) + + def blockmatrix_type(self, bmir): + payload = IRTypePayload(ir=self._render_ir(bmir)) + type_bytes, _ = self._rpc(ActionTag.BLOCK_MATRIX_TYPE, payload) + return tblockmatrix._from_json(orjson.loads(type_bytes)) - @abc.abstractmethod def load_references_from_dataset(self, path): - pass + payload = LoadReferencesFromDatasetPayload(path=path) + references_json_bytes, _ = self._rpc(ActionTag.LOAD_REFERENCES_FROM_DATASET, payload) + return orjson.loads(references_json_bytes) - @abc.abstractmethod def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par): - pass + payload = FromFASTAFilePayload(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par) + rg_json_bytes, _ = self._rpc(ActionTag.FROM_FASTA_FILE, payload) + return orjson.loads(rg_json_bytes) def add_reference(self, rg): self._references[rg.name] = rg @@ -184,9 +260,10 @@ def add_liftover(self, name, chain_file, dest_reference_genome): def remove_liftover(self, name, dest_reference_genome): pass - @abc.abstractmethod def parse_vcf_metadata(self, path): - pass + payload = ParseVCFMetadataPayload(path) + metadata_json_bytes, _ = self._rpc(ActionTag.PARSE_VCF_METADATA, payload) + return orjson.loads(metadata_json_bytes) @property @abc.abstractmethod @@ -198,9 +275,10 @@ def logger(self): def fs(self) -> FS: pass - @abc.abstractmethod def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str): - pass + payload = ImportFamPayload(path, quant_pheno, delimiter, missing) + fam_json_bytes, _ = self._rpc(ActionTag.IMPORT_FAM, payload) + return orjson.loads(fam_json_bytes) def persist(self, dataset: Dataset) -> Dataset: from hail.context import TemporaryFilename @@ -242,7 +320,7 @@ def _initialize_flags(self, initial_flags: Dict[str, str]) -> None: }, **initial_flags) @abc.abstractmethod - def set_flags(self, **flags: Mapping[str, str]): + def set_flags(self, **flags: str): """Set Hail flags.""" pass diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 7ad4ab2ec67..fcfe377ac77 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -1,19 +1,13 @@ -from typing import Optional, Union, Tuple, List, Set +from typing import Optional, Union, Tuple, List import os -import socket -import socketserver import sys -from threading import Thread -import py4j from py4j.java_gateway import JavaGateway, GatewayParameters, launch_gateway -from hail.utils.java import scala_package_object from hail.ir.renderer import CSERenderer from hail.ir import finalize_randomness -from .py4j_backend import Py4JBackend, handle_java_exception +from .py4j_backend import Py4JBackend, uninstall_exception_handler from .backend import local_jar_information -from ..hail_logging import Logger from ..expr import Expression from ..expr.types import HailType @@ -22,96 +16,6 @@ from hailtop.aiotools.validators import validate_file -_installed = False -_original = None - - -def install_exception_handler(): - global _installed - global _original - if not _installed: - _original = py4j.protocol.get_return_value - _installed = True - # The original `get_return_value` is not patched, it's idempotent. - patched = handle_java_exception(_original) - # only patch the one used in py4j.java_gateway (call Java API) - py4j.java_gateway.get_return_value = patched - - -def uninstall_exception_handler(): - global _installed - global _original - if _installed: - _installed = False - py4j.protocol.get_return_value = _original - - -class LoggingTCPHandler(socketserver.StreamRequestHandler): - def handle(self): - for line in self.rfile: - sys.stderr.write(line.decode("ISO-8859-1")) - - -class SimpleServer(socketserver.ThreadingMixIn, socketserver.TCPServer): - daemon_threads = True - allow_reuse_address = True - - def __init__(self, server_address, handler_class): - socketserver.TCPServer.__init__(self, server_address, handler_class) - - -def connect_logger(utils_package_object, host, port): - """ - This method starts a simple server which listens on a port for a - client to connect and start writing messages. Whenever a message - is received, it is written to sys.stderr. The server is run in - a daemon thread from the caller, which is killed when the caller - thread dies. - - If the socket is in use, then the server tries to listen on the - next port (port + 1). After 25 tries, it gives up. - - :param str host: Hostname for server. - :param int port: Port to listen on. - """ - server = None - tries = 0 - max_tries = 25 - while not server: - try: - server = SimpleServer((host, port), LoggingTCPHandler) - except socket.error: - port += 1 - tries += 1 - - if tries >= max_tries: - sys.stderr.write( - 'WARNING: Could not find a free port for logger, maximum retries {} exceeded.'.format(max_tries)) - return - - t = Thread(target=server.serve_forever, args=()) - - # The thread should be a daemon so that it shuts down when the parent thread is killed - t.daemon = True - - t.start() - utils_package_object.addSocketAppender(host, port) - - -class Log4jLogger(Logger): - def __init__(self, log_pkg): - self._log_pkg = log_pkg - - def error(self, msg): - self._log_pkg.error(msg) - - def warning(self, msg): - self._log_pkg.warn(msg) - - def info(self, msg): - self._log_pkg.info(msg) - - class LocalBackend(Py4JBackend): def __init__(self, tmpdir, log, quiet, append, branching_factor, skip_logging_configuration, optimizer_iterations, @@ -120,7 +24,6 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor, gcs_requester_pays_project: Optional[str] = None, gcs_requester_pays_buckets: Optional[str] = None ): - super(LocalBackend, self).__init__() assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None spark_home = find_spark_home() @@ -149,14 +52,10 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor, die_on_exit=True) self._gateway = JavaGateway( gateway_parameters=GatewayParameters(port=port, auto_convert=True)) - self._jvm = self._gateway.jvm - hail_package = getattr(self._jvm, 'is').hail + hail_package = getattr(self._gateway.jvm, 'is').hail - self._hail_package = hail_package - self._utils_package_object = scala_package_object(hail_package.utils) - - self._jbackend = hail_package.backend.local.LocalBackend.apply( + jbackend = hail_package.backend.local.LocalBackend.apply( tmpdir, gcs_requester_pays_project, gcs_requester_pays_buckets, @@ -165,22 +64,13 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor, append, skip_logging_configuration ) - self._jhc = hail_package.HailContext.apply( - self._jbackend, branching_factor, optimizer_iterations) - self._registered_ir_function_names: Set[str] = set() - - # This has to go after creating the SparkSession. Unclear why. - # Maybe it does its own patch? - install_exception_handler() - - from hail.context import version + jhc = hail_package.HailContext.apply( + jbackend, + branching_factor, + optimizer_iterations + ) - py_version = version() - jar_version = self._jhc.version() - if jar_version != py_version: - raise RuntimeError(f"Hail version mismatch between JAR and Python library\n" - f" JAR: {jar_version}\n" - f" Python: {py_version}") + super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc) self._fs = RouterFS() self._logger = None @@ -190,15 +80,6 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor, def validate_file(self, uri: str) -> None: validate_file(uri, self._fs.afs) - def jvm(self): - return self._jvm - - def hail_package(self): - return self._hail_package - - def utils_package_object(self): - return self._utils_package_object - def register_ir_function(self, name: str, type_parameters: Union[Tuple[HailType, ...], List[HailType]], @@ -206,9 +87,9 @@ def register_ir_function(self, value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], return_type: HailType, body: Expression): - r = CSERenderer(stop_at_jir=True) + r = CSERenderer() code = r(finalize_randomness(body._ir)) - jbody = (self._parse_value_ir(code, ref_map=dict(zip(value_parameter_names, value_parameter_types)), ir_map=r.jirs)) + jbody = self._parse_value_ir(code, ref_map=dict(zip(value_parameter_names, value_parameter_types))) self._registered_ir_function_names.add(name) self.hail_package().expr.ir.functions.IRFunctionRegistry.pyRegisterIR( @@ -219,22 +100,11 @@ def register_ir_function(self, return_type._parsable_string(), jbody) - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def stop(self): - self._jhc.stop() - self._jhc = None + super().stop() self._gateway.shutdown() - self._registered_ir_function_names = set() uninstall_exception_handler() - @property - def logger(self): - if self._logger is None: - self._logger = Log4jLogger(self._utils_package_object) - return self._logger - @property def fs(self): return self._fs diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 592026ae382..7036ae49153 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -1,21 +1,52 @@ -from typing import Mapping +from typing import Mapping, Set, Tuple import abc -import json +import socket +import socketserver +import sys +from threading import Thread +import orjson +import requests import py4j -import py4j.java_gateway +from py4j.java_gateway import JavaObject, JVMView import hail from hail.expr import construct_expr -from hail.ir import JavaIR, finalize_randomness +from hail.ir import finalize_randomness, JavaIR from hail.ir.renderer import CSERenderer -from hail.utils.java import FatalError, Env -from hail.expr.blockmatrix_type import tblockmatrix -from hail.expr.matrix_type import tmatrix -from hail.expr.table_type import ttable -from hail.expr.types import dtype +from hail.utils.java import FatalError, Env, scala_package_object -from .backend import Backend, fatal_error_from_java_error_triplet +from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet +from ..hail_logging import Logger + +import http.client +# This defaults to 65536 and fails if a header is longer than _MAXLINE +# The timing json that we output can exceed 65536 bytes so we raise the limit +http.client._MAXLINE = 2 ** 20 + + +_installed = False +_original = None + + +def install_exception_handler(): + global _installed + global _original + if not _installed: + _original = py4j.protocol.get_return_value + _installed = True + # The original `get_return_value` is not patched, it's idempotent. + patched = handle_java_exception(_original) + # only patch the one used in py4j.java_gateway (call Java API) + py4j.java_gateway.get_return_value = patched + + +def uninstall_exception_handler(): + global _installed + global _original + if _installed: + _installed = False + py4j.protocol.get_return_value = _original def handle_java_exception(f): @@ -41,11 +72,88 @@ def deco(*args, **kwargs): return deco -class Py4JBackend(Backend): - _jbackend: py4j.java_gateway.JavaObject +class SimpleServer(socketserver.ThreadingMixIn, socketserver.TCPServer): + daemon_threads = True + allow_reuse_address = True + + def __init__(self, server_address, handler_class): + socketserver.TCPServer.__init__(self, server_address, handler_class) + + +class LoggingTCPHandler(socketserver.StreamRequestHandler): + def handle(self): + for line in self.rfile: + sys.stderr.write(line.decode("ISO-8859-1")) + + +class Log4jLogger(Logger): + def __init__(self, log_pkg): + self._log_pkg = log_pkg + + def error(self, msg): + self._log_pkg.error(msg) + + def warning(self, msg): + self._log_pkg.warn(msg) + + def info(self, msg): + self._log_pkg.info(msg) + + +def connect_logger(utils_package_object, host, port): + """ + This method starts a simple server which listens on a port for a + client to connect and start writing messages. Whenever a message + is received, it is written to sys.stderr. The server is run in + a daemon thread from the caller, which is killed when the caller + thread dies. + + If the socket is in use, then the server tries to listen on the + next port (port + 1). After 25 tries, it gives up. + + :param str host: Hostname for server. + :param int port: Port to listen on. + """ + server = None + tries = 0 + max_tries = 25 + while not server: + try: + server = SimpleServer((host, port), LoggingTCPHandler) + except socket.error: + port += 1 + tries += 1 + + if tries >= max_tries: + sys.stderr.write( + 'WARNING: Could not find a free port for logger, maximum retries {} exceeded.'.format(max_tries)) + return + + t = Thread(target=server.serve_forever, args=()) + + # The thread should be a daemon so that it shuts down when the parent thread is killed + t.daemon = True + + t.start() + utils_package_object.addSocketAppender(host, port) + + +action_routes = { + ActionTag.VALUE_TYPE: '/value/type', + ActionTag.TABLE_TYPE: '/table/type', + ActionTag.MATRIX_TABLE_TYPE: '/matrixtable/type', + ActionTag.BLOCK_MATRIX_TYPE: '/blockmatrix/type', + ActionTag.LOAD_REFERENCES_FROM_DATASET: '/references/load', + ActionTag.FROM_FASTA_FILE: '/references/from_fasta', + ActionTag.EXECUTE: '/execute', + ActionTag.PARSE_VCF_METADATA: '/vcf/metadata/parse', + ActionTag.IMPORT_FAM: '/fam/import', +} + +class Py4JBackend(Backend): @abc.abstractmethod - def __init__(self): + def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject): super(Py4JBackend, self).__init__() import base64 @@ -56,40 +164,66 @@ def decode_bytearray(encoded): # work to support python 2. This eliminates that. py4j.protocol.decode_bytearray = decode_bytearray - @abc.abstractmethod + self._jvm = jvm + self._hail_package = getattr(self._jvm, 'is').hail + self._utils_package_object = scala_package_object(self._hail_package.utils) + self._jbackend = jbackend + self._jhc = jhc + + self._backend_server = self._hail_package.backend.BackendServer.apply(self._jbackend) + self._backend_server_port: int = self._backend_server.port() + self._backend_server.start() + self._requests_session = requests.Session() + + self._registered_ir_function_names: Set[str] = set() + + # This has to go after creating the SparkSession. Unclear why. + # Maybe it does its own patch? + install_exception_handler() + from hail.context import version + + py_version = version() + jar_version = self._jhc.version() + if jar_version != py_version: + raise RuntimeError(f"Hail version mismatch between JAR and Python library\n" + f" JAR: {jar_version}\n" + f" Python: {py_version}") + def jvm(self): - pass + return self._jvm - @abc.abstractmethod def hail_package(self): - pass + return self._hail_package - @abc.abstractmethod def utils_package_object(self): - pass + return self._utils_package_object - def execute(self, ir, timed=False): - jir = self._to_java_value_ir(ir) - stream_codec = '{"name":"StreamBufferSpec"}' - # print(self._hail_package.expr.ir.Pretty.apply(jir, True, -1)) - try: - result_tuple = self._jbackend.executeEncode(jir, stream_codec, timed) - (result, timings) = (result_tuple._1(), result_tuple._2()) - value = ir.typ._from_encoding(result) - - return (value, timings) if timed else value - except FatalError as e: - raise e.maybe_user_error(ir) from None - - async def _async_execute(self, ir, timed=False): - raise NotImplementedError('no async available in Py4JBackend') + @property + def logger(self): + if self._logger is None: + self._logger = Log4jLogger(self._utils_package_object) + return self._logger + + def _rpc(self, action, payload) -> Tuple[bytes, str]: + data = orjson.dumps(payload) + path = action_routes[action] + port = self._backend_server_port + resp = self._requests_session.post(f'http://localhost:{port}{path}', data=data) + if resp.status_code >= 400: + error_json = orjson.loads(resp.content) + raise fatal_error_from_java_error_triplet(error_json['short'], error_json['expanded'], error_json['error_id']) + return resp.content, resp.headers.get('X-Hail-Timings', '') def persist_expression(self, expr): + t = expr.dtype return construct_expr( - JavaIR(self._jbackend.executeLiteral(self._to_java_value_ir(expr._ir))), - expr.dtype + JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), + t ) + def _is_registered_ir_function_name(self, name: str) -> bool: + return name in self._registered_ir_function_names + def set_flags(self, **flags: Mapping[str, str]): available = self._jbackend.availableFlags() invalid = [] @@ -106,17 +240,11 @@ def get_flags(self, *flags) -> Mapping[str, str]: return {flag: self._jbackend.getFlag(flag) for flag in flags} def _add_reference_to_scala_backend(self, rg): - self._jbackend.pyAddReference(json.dumps(rg._config)) + self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8')) def _remove_reference_from_scala_backend(self, name): self._jbackend.pyRemoveReference(name) - def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par): - return json.loads(self._jbackend.pyFromFASTAFile(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par)) - - def load_references_from_dataset(self, path): - return json.loads(self._jbackend.pyLoadReferencesFromDataset(path)) - def add_sequence(self, name, fasta_file, index_file): self._jbackend.pyAddSequence(name, fasta_file, index_file) @@ -129,65 +257,40 @@ def add_liftover(self, name, chain_file, dest_reference_genome): def remove_liftover(self, name, dest_reference_genome): self._jbackend.pyRemoveLiftover(name, dest_reference_genome) - def parse_vcf_metadata(self, path): - return json.loads(self._jhc.pyParseVCFMetadataJSON(self._jbackend.fs(), path)) - def index_bgen(self, files, index_file_map, referenceGenomeName, contig_recoding, skip_invalid_loci): self._jbackend.pyIndexBgen(files, index_file_map, referenceGenomeName, contig_recoding, skip_invalid_loci) - def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str): - return json.loads(self._jbackend.pyImportFam(path, quant_pheno, delimiter, missing)) - def _to_java_ir(self, ir, parse): if not hasattr(ir, '_jir'): - r = CSERenderer(stop_at_jir=True) + r = CSERenderer() # FIXME parse should be static - ir._jir = parse(r(finalize_randomness(ir)), ir_map=r.jirs) + ir._jir = parse(r(finalize_randomness(ir))) return ir._jir - def _parse_value_ir(self, code, ref_map={}, ir_map={}): + def _parse_value_ir(self, code, ref_map={}): return self._jbackend.parse_value_ir( code, {k: t._parsable_string() for k, t in ref_map.items()}, - ir_map) + ) - def _parse_table_ir(self, code, ir_map={}): - return self._jbackend.parse_table_ir(code, ir_map) + def _parse_table_ir(self, code): + return self._jbackend.parse_table_ir(code) - def _parse_matrix_ir(self, code, ir_map={}): - return self._jbackend.parse_matrix_ir(code, ir_map) + def _parse_matrix_ir(self, code): + return self._jbackend.parse_matrix_ir(code) - def _parse_blockmatrix_ir(self, code, ir_map={}): - return self._jbackend.parse_blockmatrix_ir(code, ir_map) + def _parse_blockmatrix_ir(self, code): + return self._jbackend.parse_blockmatrix_ir(code) def _to_java_value_ir(self, ir): return self._to_java_ir(ir, self._parse_value_ir) - def _to_java_table_ir(self, ir): - return self._to_java_ir(ir, self._parse_table_ir) - - def _to_java_matrix_ir(self, ir): - return self._to_java_ir(ir, self._parse_matrix_ir) - def _to_java_blockmatrix_ir(self, ir): return self._to_java_ir(ir, self._parse_blockmatrix_ir) - def value_type(self, ir): - jir = self._to_java_value_ir(ir) - return dtype(jir.typ().toString()) - - def table_type(self, tir): - jir = self._to_java_table_ir(tir) - return ttable._from_java(jir.typ()) - - def matrix_type(self, mir): - jir = self._to_java_matrix_ir(mir) - return tmatrix._from_java(jir.typ()) - - def blockmatrix_type(self, bmir): - jir = self._to_java_blockmatrix_ir(bmir) - return tblockmatrix._from_java(jir.typ()) - - @property - def requires_lowering(self): - return True + def stop(self): + self._backend_server.stop() + self._jhc.stop() + self._jhc = None + self._registered_ir_function_names = set() + uninstall_exception_handler() diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index d301f9b8e82..04df3d98f40 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple, TypeVar, Set +from typing import Dict, Optional, Awaitable, Mapping, Any, List, Union, Tuple, TypeVar, Set import abc import asyncio +from dataclasses import dataclass import math import struct from hail.expr.expressions.base_expression import Expression @@ -8,19 +9,16 @@ import logging import warnings -from hail.context import TemporaryDirectory, tmp_dir, TemporaryFilename, revision, version +from hail.context import TemporaryDirectory, TemporaryFilename, tmp_dir, revision, version from hail.utils import FatalError -from hail.expr.types import HailType, dtype, ttuple, tvoid -from hail.expr.table_type import ttable -from hail.expr.matrix_type import tmatrix -from hail.expr.blockmatrix_type import tblockmatrix -from hail.experimental import write_expression, read_expression +from hail.expr.types import HailType +from hail.experimental import read_expression, write_expression from hail.ir import finalize_randomness from hail.ir.renderer import CSERenderer from hailtop import yamlx from hailtop.config import (ConfigVariable, configuration_of, get_remote_tmpdir) -from hailtop.utils import async_to_blocking, TransientError, Timings, am_i_interactive, retry_transient_errors +from hailtop.utils import async_to_blocking, Timings, am_i_interactive, retry_transient_errors from hailtop.utils.rich_progress_bar import BatchProgressBar from hailtop.batch_client import client as hb from hailtop.batch_client import aioclient as aiohb @@ -31,9 +29,8 @@ from hailtop.fs.router_fs import RouterFS from hailtop.aiotools.fs.exceptions import UnexpectedEOFError -from .backend import Backend, fatal_error_from_java_error_triplet +from .backend import Backend, fatal_error_from_java_error_triplet, ActionTag, ActionPayload, ExecutePayload from ..builtin_references import BUILTIN_REFERENCES -from ..ir import BaseIR from ..utils import ANY_REGION from hailtop.aiotools.validators import validate_file @@ -47,41 +44,6 @@ log = logging.getLogger('backend.service_backend') -async def write_bool(strm: afs.WritableStream, v: bool): - if v: - await strm.write(b'\x01') - else: - await strm.write(b'\x00') - - -async def write_int(strm: afs.WritableStream, v: int): - await strm.write(struct.pack(' int: return (await strm.readexactly(1))[0] @@ -95,11 +57,6 @@ async def read_int(strm: afs.ReadableStream) -> int: return struct.unpack(' int: - b = await strm.readexactly(8) - return struct.unpack(' bytes: n = await read_int(strm) return await strm.readexactly(n) @@ -138,6 +95,16 @@ def __repr__(self): return f'GitRevision({self.revision})' +@dataclass +class SerializedIRFunction: + name: str + type_parameters: List[str] + value_parameter_names: List[str] + value_parameter_types: List[str] + return_type: str + rendered_body: str + + class IRFunction: def __init__(self, name: str, @@ -147,7 +114,7 @@ def __init__(self, return_type: HailType, body: Expression): assert len(value_parameter_names) == len(value_parameter_types) - render = CSERenderer(stop_at_jir=True) + render = CSERenderer() self._name = name self._type_parameters = type_parameters self._value_parameter_names = value_parameter_names @@ -155,23 +122,51 @@ def __init__(self, self._return_type = return_type self._rendered_body = render(finalize_randomness(body._ir)) - async def serialize(self, writer: afs.WritableStream): - await write_str(writer, self._name) + def to_dataclass(self): + return SerializedIRFunction( + name=self._name, + type_parameters=[tp._parsable_string() for tp in self._type_parameters], + value_parameter_names=list(self._value_parameter_names), + value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], + return_type=self._return_type._parsable_string(), + rendered_body=self._rendered_body, + ) + + +@dataclass +class ServiceBackendExecutePayload(ActionPayload): + functions: List[SerializedIRFunction] + idempotency_token: str + payload: ExecutePayload + - await write_int(writer, len(self._type_parameters)) - for type_parameter in self._type_parameters: - await write_str(writer, type_parameter._parsable_string()) +@dataclass +class CloudfuseConfig: + bucket: str + mount_path: str + read_only: bool - await write_int(writer, len(self._value_parameter_names)) - for value_parameter_name in self._value_parameter_names: - await write_str(writer, value_parameter_name) - await write_int(writer, len(self._value_parameter_types)) - for value_parameter_type in self._value_parameter_types: - await write_str(writer, value_parameter_type._parsable_string()) +@dataclass +class SequenceConfig: + fasta: str + index: str - await write_str(writer, self._return_type._parsable_string()) - await write_str(writer, self._rendered_body) + +@dataclass +class ServiceBackendRPCConfig: + tmp_dir: str + remote_tmpdir: str + billing_project: str + worker_cores: str + worker_memory: str + storage: str + cloudfuse_configs: List[CloudfuseConfig] + regions: List[str] + flags: Dict[str, str] + custom_references: List[str] + liftovers: Dict[str, Dict[str, str]] + sequences: Dict[str, SequenceConfig] class ServiceBackend(Backend): @@ -363,76 +358,26 @@ def stop(self): self.functions = [] self._registered_ir_function_names = set() - def render(self, ir): - r = CSERenderer() - assert len(r.jirs) == 0 - return r(finalize_randomness(ir)) - - async def _rpc(self, - name: str, - inputs: Callable[[afs.WritableStream, str], Awaitable[None]], - *, - ir: Optional[BaseIR] = None, - progress: Optional[BatchProgressBar] = None, - driver_cores: Optional[Union[int, str]] = None, - driver_memory: Optional[str] = None, - worker_cores: Optional[Union[int, str]] = None, - worker_memory: Optional[str] = None, - ): + async def _run_on_batch( + self, + name: str, + service_backend_config: ServiceBackendRPCConfig, + action: ActionTag, + payload: ActionPayload, + *, + progress: Optional[BatchProgressBar] = None, + driver_cores: Optional[Union[int, str]] = None, + driver_memory: Optional[str] = None, + ) -> Tuple[bytes, str]: timings = Timings() with TemporaryDirectory(ensure_exists=False) as iodir: - readonly_fuse_buckets = set() - storage_requirement_bytes = 0 - with timings.step("write input"): async with await self._async_fs.create(iodir + '/in') as infile: - nonnull_flag_count = sum(v is not None for v in self.flags.values()) - await write_int(infile, nonnull_flag_count) - for k, v in self.flags.items(): - if v is not None: - await write_str(infile, k) - await write_str(infile, v) - custom_references = [rg for rg in self._references.values() if rg.name not in BUILTIN_REFERENCES] - await write_int(infile, len(custom_references)) - for reference_config in custom_references: - await write_str(infile, orjson.dumps(reference_config._config).decode('utf-8')) - non_empty_liftovers = {rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0} - await write_int(infile, len(non_empty_liftovers)) - for source_genome_name, liftovers in non_empty_liftovers.items(): - await write_str(infile, source_genome_name) - await write_int(infile, len(liftovers)) - for dest_reference_genome, chain_file in liftovers.items(): - await write_str(infile, dest_reference_genome) - await write_str(infile, chain_file) - added_sequences = {rg.name: rg._sequence_files for rg in self._references.values() if rg._sequence_files is not None} - await write_int(infile, len(added_sequences)) - for rg_name, (fasta_file, index_file) in added_sequences.items(): - await write_str(infile, rg_name) - for blob in (fasta_file, index_file): - bucket, path = self._get_bucket_and_path(blob) - readonly_fuse_buckets.add(bucket) - storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size() - await write_str(infile, f'/cloudfuse/{bucket}/{path}') - if worker_cores is not None: - await write_str(infile, str(worker_cores)) - else: - await write_str(infile, str(self.worker_cores)) - if worker_memory is not None: - await write_str(infile, str(worker_memory)) - else: - await write_str(infile, str(self.worker_memory)) - await write_int(infile, len(self.regions)) - for region in self.regions: - await write_str(infile, region) - storage_gib_str = f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi' - await write_str(infile, storage_gib_str) - cloudfuse_config = [(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets] - await write_int(infile, len(cloudfuse_config)) - for bucket, mount_point, readonly in cloudfuse_config: - await write_str(infile, bucket) - await write_str(infile, mount_point) - await write_bool(infile, readonly) - await inputs(infile, self._batch.token) + await infile.write(orjson.dumps({ + 'config': service_backend_config, + 'action': action.value, + 'payload': payload, + })) with timings.step("submit batch"): resources: Dict[str, Union[str, bool]] = {'preemptible': False} @@ -446,8 +391,8 @@ async def _rpc(self, elif self.driver_memory is not None: resources['memory'] = str(self.driver_memory) - if storage_requirement_bytes != 0: - resources['storage'] = storage_gib_str + if service_backend_config.storage != '0Gi': + resources['storage'] = service_backend_config.storage j = self._batch.create_jvm_job( jar_spec=self.jar_spec.to_dict(), @@ -461,7 +406,7 @@ async def _rpc(self, resources=resources, attributes={'name': name + '_driver'}, regions=self.regions, - cloudfuse=cloudfuse_config, + cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in service_backend_config.cloudfuse_configs], profile=self.flags['profile'] is not None, ) await self._batch.submit(disable_progress_bar=True) @@ -485,10 +430,10 @@ async def _rpc(self, raise with timings.step("read output"): - result_bytes = await retry_transient_errors(self._read_output, ir, iodir + '/out', iodir + '/in') - return result_bytes, timings + result_bytes = await retry_transient_errors(self._read_output, iodir + '/out', iodir + '/in') + return result_bytes, str(timings.to_dict()) - async def _read_output(self, ir: Optional[BaseIR], output_uri: str, input_uri: str) -> bytes: + async def _read_output(self, output_uri: str, input_uri: str) -> bytes: try: driver_output = await self._async_fs.open(output_uri) except FileNotFoundError as exc: @@ -511,10 +456,7 @@ async def _read_output(self, ir: Optional[BaseIR], output_uri: str, input_uri: s expanded_message = await read_str(outfile) error_id = await read_int(outfile) - reconstructed_error = fatal_error_from_java_error_triplet(short_message, expanded_message, error_id) - if ir is None: - raise reconstructed_error - raise reconstructed_error.maybe_user_error(ir) + raise fatal_error_from_java_error_triplet(short_message, expanded_message, error_id) except UnexpectedEOFError as exc: raise FatalError('Hail internal error. Please contact the Hail team and provide the following information.\n\n' + yamlx.dump({ 'service_backend_debug_info': self.debug_info(), @@ -537,127 +479,42 @@ def _cancel_on_ctrl_c(self, coro: Awaitable[T]) -> T: self._batch_was_submitted = False raise - def execute(self, ir: BaseIR, timed: bool = False, **kwargs): - return self._cancel_on_ctrl_c(self._async_execute(ir, timed=timed, **kwargs)) - - async def _async_execute(self, - ir: BaseIR, - *, - timed: bool = False, - progress: Optional[BatchProgressBar] = None, - **kwargs): - async def inputs(infile, token): - await write_int(infile, ServiceBackend.EXECUTE) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, self.render(ir)) - await write_str(infile, token) - await write_int(infile, len(self.functions)) - for fun in self.functions: - await fun.serialize(infile) - await write_str(infile, '{"name":"StreamBufferSpec"}') - - resp, timings = await self._rpc( - 'execute(...)', - inputs, - ir=ir, - progress=progress, - **kwargs + def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]: + return self._cancel_on_ctrl_c(self._async_rpc(action, payload)) + + async def _async_rpc(self, action: ActionTag, payload: ActionPayload): + if isinstance(payload, ExecutePayload): + payload = ServiceBackendExecutePayload([f.to_dataclass() for f in self.functions], self._batch.token, payload) + + storage_requirement_bytes = 0 + readonly_fuse_buckets: Set[str] = set() + + added_sequences = {rg.name: rg._sequence_files for rg in self._references.values() if rg._sequence_files is not None} + sequence_file_mounts = {} + for rg_name, (fasta_file, index_file) in added_sequences.items(): + fasta_bucket, fasta_path = self._get_bucket_and_path(fasta_file) + index_bucket, index_path = self._get_bucket_and_path(index_file) + for bucket, blob in [(fasta_bucket, fasta_file), (index_bucket, index_file)]: + readonly_fuse_buckets.add(bucket) + storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size() + sequence_file_mounts[rg_name] = SequenceConfig(f'/cloudfuse/{fasta_bucket}/{fasta_path}', f'/cloudfuse/{index_bucket}/{index_path}') + + storage_gib_str = f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi' + qob_config = ServiceBackendRPCConfig( + tmp_dir=tmp_dir(), + remote_tmpdir=self.remote_tmpdir, + billing_project=self.billing_project, + worker_cores=str(self.worker_cores), + worker_memory=str(self.worker_memory), + storage=storage_gib_str, + cloudfuse_configs=[CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets], + regions=self.regions, + flags=self.flags, + custom_references=[orjson.dumps(rg._config).decode('utf-8') for rg in self._references.values() if rg.name not in BUILTIN_REFERENCES], + liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}, + sequences=sequence_file_mounts, ) - typ: HailType = ir.typ - if typ == tvoid: - assert resp == b'', (typ, resp) - converted_value = None - else: - converted_value = ttuple(typ)._from_encoding(resp)[0] - if timed: - return converted_value, timings - return converted_value - - def value_type(self, ir): - return self._cancel_on_ctrl_c(self._async_value_type(ir)) - - async def _async_value_type(self, ir, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.VALUE_TYPE) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, self.render(ir)) - resp, _ = await self._rpc('value_type(...)', inputs, progress=progress) - return dtype(orjson.loads(resp)) - - def table_type(self, tir): - return self._cancel_on_ctrl_c(self._async_table_type(tir)) - - async def _async_table_type(self, tir, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.TABLE_TYPE) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, self.render(tir)) - resp, _ = await self._rpc('table_type(...)', inputs, progress=progress) - return ttable._from_json(orjson.loads(resp)) - - def matrix_type(self, mir): - return self._cancel_on_ctrl_c(self._async_matrix_type(mir)) - - async def _async_matrix_type(self, mir, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.MATRIX_TABLE_TYPE) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, self.render(mir)) - resp, _ = await self._rpc('matrix_type(...)', inputs, progress=progress) - return tmatrix._from_json(orjson.loads(resp)) - - def blockmatrix_type(self, bmir): - return self._cancel_on_ctrl_c(self._async_blockmatrix_type(bmir)) - - async def _async_blockmatrix_type(self, bmir, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.BLOCK_MATRIX_TYPE) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, self.render(bmir)) - resp, _ = await self._rpc('blockmatrix_type(...)', inputs, progress=progress) - return tblockmatrix._from_json(orjson.loads(resp)) - - def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par): - return async_to_blocking(self._from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par)) - - async def _from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.FROM_FASTA_FILE) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, name) - await write_str(infile, fasta_file) - await write_str(infile, index_file) - await write_str_array(infile, x_contigs) - await write_str_array(infile, y_contigs) - await write_str_array(infile, mt_contigs) - await write_str_array(infile, par) - resp, _ = await self._rpc('from_fasta_file(...)', inputs, progress=progress) - return orjson.loads(resp) - - def load_references_from_dataset(self, path): - return self._cancel_on_ctrl_c(self._async_load_references_from_dataset(path)) - - async def _async_load_references_from_dataset(self, path, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.LOAD_REFERENCES_FROM_DATASET) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, path) - resp, _ = await self._rpc('load_references_from_dataset(...)', inputs, progress=progress) - return orjson.loads(resp) + return await self._run_on_batch(f'{action.name.lower()}(...)', qob_config, action, payload) # Sequence and liftover information is stored on the ReferenceGenome # and there is no persistent backend to keep in sync. @@ -680,41 +537,6 @@ def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): def remove_liftover(self, name, dest_reference_genome): # pylint: disable=unused-argument pass - def parse_vcf_metadata(self, path): - return self._cancel_on_ctrl_c(self._async_parse_vcf_metadata(path)) - - async def _async_parse_vcf_metadata(self, path, *, progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.PARSE_VCF_METADATA) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, path) - resp, _ = await self._rpc('parse_vcf_metadata(...)', inputs, progress=progress) - return orjson.loads(resp) - - def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str): - return self._cancel_on_ctrl_c(self._async_import_fam(path, quant_pheno, delimiter, missing)) - - async def _async_import_fam(self, - path: str, - quant_pheno: bool, - delimiter: str, - missing: str, - *, - progress: Optional[BatchProgressBar] = None): - async def inputs(infile, _): - await write_int(infile, ServiceBackend.IMPORT_FAM) - await write_str(infile, tmp_dir()) - await write_str(infile, self.billing_project) - await write_str(infile, self.remote_tmpdir) - await write_str(infile, path) - await write_bool(infile, quant_pheno) - await write_str(infile, delimiter) - await write_str(infile, missing) - resp, _ = await self._rpc('import_fam(...)', inputs, progress=progress) - return orjson.loads(resp) - def register_ir_function(self, name: str, type_parameters: Union[Tuple[HailType, ...], List[HailType]], diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index a09b6f0d013..d88d113ed11 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -1,120 +1,22 @@ -from typing import Set import sys import os -import json -import socket -import socketserver -from threading import Thread -import py4j import pyspark import pyspark.sql -from typing import List, Optional +import orjson +from typing import Optional -import hail as hl -from hail.utils.java import scala_package_object +from hail.expr.table_type import ttable from hail.fs.hadoop_fs import HadoopFS from hail.ir.renderer import CSERenderer from hail.table import Table -from hail.matrixtable import MatrixTable from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.aiotools.validators import validate_file -from .py4j_backend import Py4JBackend, handle_java_exception -from ..hail_logging import Logger +from .py4j_backend import Py4JBackend from .backend import local_jar_information -_installed = False -_original = None - - -def install_exception_handler(): - global _installed - global _original - if not _installed: - _original = py4j.protocol.get_return_value - _installed = True - # The original `get_return_value` is not patched, it's idempotent. - patched = handle_java_exception(_original) - # only patch the one used in py4j.java_gateway (call Java API) - py4j.java_gateway.get_return_value = patched - - -def uninstall_exception_handler(): - global _installed - global _original - if _installed: - _installed = False - py4j.protocol.get_return_value = _original - - -class LoggingTCPHandler(socketserver.StreamRequestHandler): - def handle(self): - for line in self.rfile: - sys.stderr.write(line.decode("ISO-8859-1")) - - -class SimpleServer(socketserver.ThreadingMixIn, socketserver.TCPServer): - daemon_threads = True - allow_reuse_address = True - - def __init__(self, server_address, handler_class): - socketserver.TCPServer.__init__(self, server_address, handler_class) - - -def connect_logger(utils_package_object, host, port): - """ - This method starts a simple server which listens on a port for a - client to connect and start writing messages. Whenever a message - is received, it is written to sys.stderr. The server is run in - a daemon thread from the caller, which is killed when the caller - thread dies. - - If the socket is in use, then the server tries to listen on the - next port (port + 1). After 25 tries, it gives up. - - :param str host: Hostname for server. - :param int port: Port to listen on. - """ - server = None - tries = 0 - max_tries = 25 - while not server: - try: - server = SimpleServer((host, port), LoggingTCPHandler) - except socket.error: - port += 1 - tries += 1 - - if tries >= max_tries: - sys.stderr.write( - 'WARNING: Could not find a free port for logger, maximum retries {} exceeded.'.format(max_tries)) - return - - t = Thread(target=server.serve_forever, args=()) - - # The thread should be a daemon so that it shuts down when the parent thread is killed - t.daemon = True - - t.start() - utils_package_object.addSocketAppender(host, port) - - -class Log4jLogger(Logger): - def __init__(self, log_pkg): - self._log_pkg = log_pkg - - def error(self, msg): - self._log_pkg.error(msg) - - def warning(self, msg): - self._log_pkg.warn(msg) - - def info(self, msg): - self._log_pkg.info(msg) - - def append_to_comma_separated_list(conf: pyspark.SparkConf, k: str, *new_values: str): old = conf.get(k, None) if old is None: @@ -131,7 +33,6 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master, gcs_requester_pays_project: Optional[str] = None, gcs_requester_pays_buckets: Optional[str] = None ): - super(SparkBackend, self).__init__() assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None try: @@ -196,49 +97,34 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master, pyspark.SparkContext._ensure_initialized() self._gateway = pyspark.SparkContext._gateway - self._jvm = pyspark.SparkContext._jvm - - hail_package = getattr(self._jvm, 'is').hail - - self._hail_package = hail_package - self._utils_package_object = scala_package_object(hail_package.utils) + jvm = pyspark.SparkContext._jvm + assert jvm + hail_package = getattr(jvm, 'is').hail jsc = sc._jsc.sc() if sc else None if idempotent: - self._jbackend = hail_package.backend.spark.SparkBackend.getOrCreate( + jbackend = hail_package.backend.spark.SparkBackend.getOrCreate( jsc, app_name, master, local, log, True, append, skip_logging_configuration, min_block_size, tmpdir, local_tmpdir, gcs_requester_pays_project, gcs_requester_pays_buckets) - self._jhc = hail_package.HailContext.getOrCreate( - self._jbackend, branching_factor, optimizer_iterations) + jhc = hail_package.HailContext.getOrCreate( + jbackend, branching_factor, optimizer_iterations) else: - self._jbackend = hail_package.backend.spark.SparkBackend.apply( + jbackend = hail_package.backend.spark.SparkBackend.apply( jsc, app_name, master, local, log, True, append, skip_logging_configuration, min_block_size, tmpdir, local_tmpdir, gcs_requester_pays_project, gcs_requester_pays_buckets) - self._jhc = hail_package.HailContext.apply( - self._jbackend, branching_factor, optimizer_iterations) + jhc = hail_package.HailContext.apply( + jbackend, branching_factor, optimizer_iterations) - self._jsc = self._jbackend.sc() + self._jsc = jbackend.sc() if sc: self.sc = sc else: - self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=self._jvm.JavaSparkContext(self._jsc)) - self._jspark_session = self._jbackend.sparkSession() + self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc)) + self._jspark_session = jbackend.sparkSession() self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session) - self._registered_ir_function_names: Set[str] = set() - - # This has to go after creating the SparkSession. Unclear why. - # Maybe it does its own patch? - install_exception_handler() - from hail.context import version - - py_version = version() - jar_version = self._jhc.version() - if jar_version != py_version: - raise RuntimeError(f"Hail version mismatch between JAR and Python library\n" - f" JAR: {jar_version}\n" - f" Python: {py_version}") + super(SparkBackend, self).__init__(jvm, jbackend, jhc) self._fs = None self._logger = None @@ -248,7 +134,7 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master, if self._jsc.uiWebUrl().isDefined(): sys.stderr.write('SparkUI available at {}\n'.format(self._jsc.uiWebUrl().get())) - self._jbackend.startProgressBar() + jbackend.startProgressBar() self._initialize_flags({}) @@ -259,29 +145,10 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master, def validate_file(self, uri: str) -> None: validate_file(uri, self._router_async_fs) - def jvm(self): - return self._jvm - - def hail_package(self): - return self._hail_package - - def utils_package_object(self): - return self._utils_package_object - def stop(self): - self._jbackend.close() - self._jhc.stop() - self._jhc = None + super().stop() self.sc.stop() self.sc = None - self._registered_ir_function_names = set() - uninstall_exception_handler() - - @property - def logger(self): - if self._logger is None: - self._logger = Log4jLogger(self._utils_package_object) - return self._logger @property def fs(self): @@ -290,19 +157,21 @@ def fs(self): return self._fs def from_spark(self, df, key): - return Table._from_java(self._jbackend.pyFromDF(df._jdf, key)) + result_tuple = self._jbackend.pyFromDF(df._jdf, key) + tir_id, type_json = result_tuple._1(), result_tuple._2() + return Table._from_java(ttable._from_json(orjson.loads(type_json)), tir_id) def to_spark(self, t, flatten): t = t.expand_types() if flatten: t = t.flatten() - return pyspark.sql.DataFrame(self._jbackend.pyToDF(self._to_java_table_ir(t._tir)), self._spark_session) + return pyspark.sql.DataFrame(self._jbackend.pyToDF(self._render_ir(t._tir)), self._spark_session) def register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, body): - r = CSERenderer(stop_at_jir=True) + r = CSERenderer() assert not body._ir.uses_randomness code = r(body._ir) - jbody = (self._parse_value_ir(code, ref_map=dict(zip(argument_names, argument_types)), ir_map=r.jirs)) + jbody = self._parse_value_ir(code, ref_map=dict(zip(argument_names, argument_types))) self._registered_ir_function_names.add(name) self.hail_package().expr.ir.functions.IRFunctionRegistry.pyRegisterIR( @@ -312,19 +181,6 @@ def register_ir_function(self, name, type_parameters, argument_names, argument_t return_type._parsable_string(), jbody) - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - - def read_multiple_matrix_tables(self, paths: 'List[str]', intervals: 'List[hl.Interval]', intervals_type): - json_repr = { - 'paths': paths, - 'intervals': intervals_type._convert_to_json(intervals), - 'intervalPointType': intervals_type.element_type.point_type._parsable_string(), - } - - results = self._jhc.backend().pyReadMultipleMatrixTables(json.dumps(json_repr)) - return [MatrixTable._from_java(jm) for jm in results] - @property def requires_lowering(self): return False diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 6cc6ffc68f8..21570e143a0 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -427,7 +427,8 @@ def init_spark(sc=None, _optimizer_iterations=None, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None ): - from hail.backend.spark_backend import SparkBackend, connect_logger + from hail.backend.py4j_backend import connect_logger + from hail.backend.spark_backend import SparkBackend log = _get_log(log) tmpdir = _get_tmpdir(tmp_dir) @@ -554,7 +555,8 @@ def init_local( _optimizer_iterations=None, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None ): - from hail.backend.local_backend import LocalBackend, connect_logger + from hail.backend.py4j_backend import connect_logger + from hail.backend.local_backend import LocalBackend log = _get_log(log) tmpdir = _get_tmpdir(tmpdir) diff --git a/hail/python/hail/expr/expressions/expression_utils.py b/hail/python/hail/expr/expressions/expression_utils.py index 25656849b9a..c2ab0ed7133 100644 --- a/hail/python/hail/expr/expressions/expression_utils.py +++ b/hail/python/hail/expr/expressions/expression_utils.py @@ -1,5 +1,5 @@ from typing import Set, Dict -from hail.typecheck import typecheck, setof, nullable +from hail.typecheck import typecheck, setof from .indices import Indices, Aggregation from ..expressions import Expression, ExpressionException, expr_any @@ -130,8 +130,8 @@ def analyze(caller: str, raise errors[0] -@typecheck(expression=expr_any, _execute_kwargs=nullable(dict)) -def eval_timed(expression, *, _execute_kwargs=None): +@typecheck(expression=expr_any) +def eval_timed(expression): """Evaluate a Hail expression, returning the result and the times taken for each stage in the evaluation process. @@ -158,12 +158,11 @@ def eval_timed(expression, *, _execute_kwargs=None): uid = Env.get_uid() ir = expression._indices.source.select_globals(**{uid: expression}).index_globals()[uid]._ir - _execute_kwargs = _execute_kwargs or {} - return Env.backend().execute(MakeTuple([ir]), timed=True, **_execute_kwargs)[0] + return Env.backend().execute(MakeTuple([ir]), timed=True)[0] -@typecheck(expression=expr_any, _execute_kwargs=nullable(dict)) -def eval(expression, *, _execute_kwargs=None): +@typecheck(expression=expr_any) +def eval(expression): """Evaluate a Hail expression, returning the result. This method is extremely useful for learning about Hail expressions and @@ -189,7 +188,7 @@ def eval(expression, *, _execute_kwargs=None): ------- Any """ - return eval_timed(expression, _execute_kwargs=_execute_kwargs)[0] + return eval_timed(expression)[0] @typecheck(expression=expr_any) diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index f7f66d3cc67..6568447b109 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -36,7 +36,7 @@ MatrixRowsHead, MatrixColsHead, MatrixRowsTail, MatrixColsTail, \ MatrixExplodeCols, CastTableToMatrix, MatrixAnnotateRowsTable, \ MatrixAnnotateColsTable, MatrixToMatrixApply, MatrixRename, \ - MatrixFilterIntervals, JavaMatrix + MatrixFilterIntervals from .blockmatrix_ir import BlockMatrixRead, BlockMatrixMap, BlockMatrixMap2, \ BlockMatrixDot, BlockMatrixBroadcast, BlockMatrixAgg, BlockMatrixFilter, \ BlockMatrixDensify, BlockMatrixSparsifier, BandSparsifier, \ @@ -259,7 +259,6 @@ 'MatrixToMatrixApply', 'MatrixRename', 'MatrixFilterIntervals', - 'JavaMatrix', 'MatrixReader', 'MatrixNativeReader', 'MatrixRangeReader', diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 777e32c1f69..2ca2a33d833 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -34,7 +34,7 @@ def __init__(self, *children): self._stack_trace = None def __str__(self): - r = PlainRenderer(stop_at_jir=False) + r = PlainRenderer() return r(self) def render_head(self, r: Renderer): diff --git a/hail/python/hail/ir/blockmatrix_ir.py b/hail/python/hail/ir/blockmatrix_ir.py index eeefc02e7b5..a6278c65fa9 100644 --- a/hail/python/hail/ir/blockmatrix_ir.py +++ b/hail/python/hail/ir/blockmatrix_ir.py @@ -408,21 +408,6 @@ def _compute_type(self, deep_typecheck): return tblockmatrix(hl.tfloat64, tensor_shape, is_row_vector, self.block_size) -class JavaBlockMatrix(BlockMatrixIR): - def __init__(self, jir): - super().__init__() - self._jir = jir - - def render_head(self, r): - return f'(JavaBlockMatrix {r.add_jir(self._jir)}' - - def _compute_type(self, deep_typecheck): - if self._type is None: - return hl.tblockmatrix._from_java(self._jir.typ()) - else: - return self._type - - def tensor_shape_to_matrix_shape(bmir): shape = bmir.typ.shape is_row_vector = bmir.typ.is_row_vector diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 85bc52aff5b..0654cf0447e 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1,4 +1,4 @@ -from typing import Callable, TypeVar, cast +from typing import Callable, Optional, TypeVar, cast from typing_extensions import ParamSpec import copy import json @@ -7,7 +7,7 @@ from hailtop.hail_decorator import decorator import hail -from hail.expr.types import dtype, HailType, hail_type, tint32, tint64, \ +from hail.expr.types import HailType, hail_type, tint32, tint64, \ tfloat32, tfloat64, tstr, tbool, tarray, tstream, tndarray, tset, tdict, \ tstruct, ttuple, tinterval, tvoid, trngstate, tlocus, tcall from hail.ir.blockmatrix_writer import BlockMatrixWriter, BlockMatrixMultiWriter @@ -3705,23 +3705,36 @@ def _compute_type(self, env, agg_env, deep_typecheck): return self.virtual_ir.typ +class JavaIRSharedReference: + def __init__(self, ir_id): + self._id = ir_id + + def __del__(self): + from hail.backend.py4j_backend import Py4JBackend + if Env._hc: + backend = Env.backend() + assert isinstance(backend, Py4JBackend) + backend._jbackend.removeJavaIR(self._id) + + class JavaIR(IR): - def __init__(self, jir): + def __init__(self, hail_type: HailType, ir_id: int, ref: Optional[JavaIRSharedReference] = None): super(JavaIR, self).__init__() - self._jir = jir - super().__init__() + self._type = hail_type + self._id = ir_id + self._ref = ref or JavaIRSharedReference(ir_id) def copy(self): - return JavaIR(self._jir) + return JavaIR(self._type, self._id, self._ref) def render_head(self, r): - return f'(JavaIR{r.add_jir(self._jir)}' + return f'(JavaIR {self._id}' def _eq(self, other): - return self._jir == other._jir + return self._id == other._id def _compute_type(self, env, agg_env, deep_typecheck): - return dtype(self._jir.typ().toString()) + return self._type def subst(ir, env, agg_env): diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index aea1a70aa1e..44fcc8f44b6 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -4,7 +4,6 @@ from hail.ir.base_ir import BaseIR, MatrixIR from hail.ir.utils import modify_deep_field, zip_with_index, zip_with_index_field, default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid import hail.ir.ir as ir -from hail.utils import FatalError from hail.utils.misc import escape_str, parsable_strings, escape_id from hail.utils.jsonx import dump_json from hail.utils.java import Env @@ -1192,22 +1191,3 @@ def _eq(self, other): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) return self.child.typ - - -class JavaMatrix(MatrixIR): - def __init__(self, jir): - super().__init__() - self._jir = jir - - def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - raise FatalError('JavaMatrix does not support randomness in consumers') - - def render_head(self, r): - return f'(JavaMatrix {r.add_jir(self._jir)}' - - def _compute_type(self, deep_typecheck): - if self._type is None: - return hl.tmatrix._from_java(self._jir.typ()) - else: - return self._type - diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index d431a9d87c8..9a1d946ef22 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -91,21 +91,13 @@ def is_empty(self) -> bool: class Renderer: @abc.abstractmethod - def add_jir(self, jir): + def __call__(self, x: 'Renderable'): pass class PlainRenderer(Renderer): - def __init__(self, stop_at_jir=False): - self.stop_at_jir = stop_at_jir + def __init__(self): self.count = 0 - self.jirs = {} - - def add_jir(self, jir): - jir_id = f'm{self.count}' - self.count += 1 - self.jirs[jir_id] = jir - return jir_id def __call__(self, x: 'Renderable'): stack = RQStack() @@ -113,23 +105,10 @@ def __call__(self, x: 'Renderable'): while x is not None or stack.non_empty(): if x is not None: - # TODO: it would be nice to put the JavaIR logic in BaseIR somewhere but this isn't trivial - if self.stop_at_jir and hasattr(x, '_jir'): - jir_id = self.add_jir(x._jir) - if isinstance(x, ir.MatrixIR): - builder.append(f'(JavaMatrix {jir_id})') - elif isinstance(x, ir.TableIR): - builder.append(f'(JavaTable {jir_id})') - elif isinstance(x, ir.BlockMatrixIR): - builder.append(f'(JavaBlockMatrix {jir_id})') - else: - assert isinstance(x, ir.IR) - builder.append(f'(JavaIR {jir_id})') - else: - head = x.render_head(self) - if head != '': - builder.append(x.render_head(self)) - stack.push(RenderableQueue(x.render_children(self), x.render_tail(self))) + head = x.render_head(self) + if head != '': + builder.append(x.render_head(self)) + stack.push(RenderableQueue(x.render_children(self), x.render_tail(self))) x = None else: top = stack.peek() @@ -153,32 +132,9 @@ def __call__(self, x: 'Renderable'): class CSERenderer(Renderer): - def __init__(self, stop_at_jir=False): - self.stop_at_jir = stop_at_jir - self.jir_count = 0 - self.jirs = {} + def __init__(self): self.memo: Dict[int, Sequence[str]] = {} - def add_jir(self, jir): - jir_id = f'm{self.jir_count}' - self.jir_count += 1 - self.jirs[jir_id] = jir - return jir_id - - def _add_jir(self, node): - jir_id = self.add_jir(node._jir) - if isinstance(node, ir.MatrixIR): - jref = f'(JavaMatrix {jir_id})' - elif isinstance(node, ir.TableIR): - jref = f'(JavaTable {jir_id})' - elif isinstance(node, ir.BlockMatrixIR): - jref = f'(JavaBlockMatrix {jir_id})' - else: - assert isinstance(node, ir.IR) - jref = f'(JavaIR {jir_id})' - - self.memo[id(node)] = jref - def __call__(self, root: 'ir.BaseIR') -> str: binding_sites = CSEAnalysisPass(self)(root) return CSEPrintPass(self)(root, binding_sites) @@ -243,10 +199,6 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: child = node.children[child_idx] - if self.renderer.stop_at_jir and hasattr(child, '_jir'): - self.renderer._add_jir(child) - continue - child_frame = frame.make_child_frame(len(stack)) if isinstance(child, ir.IR): diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 9aa822ab0b0..38fcaae51c3 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -1210,15 +1210,23 @@ def _handle_randomness(self, uid_field_name): class JavaTable(TableIR): - def __init__(self, jir): + def __init__(self, table_type, tir_id: int): super().__init__() - self._jir = jir + self._type = table_type + self._id = tir_id def _handle_randomness(self, uid_field_name): raise FatalError('JavaTable does not support randomness in consumers') def render_head(self, r): - return f'(JavaTable {r.add_jir(self._jir)}' + return f'(JavaTable {self._id}' def _compute_type(self, deep_typecheck): - return hl.ttable._from_java(self._jir.typ()) + return self._type + + def __del__(self): + from hail.backend.py4j_backend import Py4JBackend + if Env._hc: + backend = Env.backend() + assert isinstance(backend, Py4JBackend) + backend._jbackend.removeJavaIR(self._id) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index a3929b22ee7..7fc32ba5c86 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -536,10 +536,6 @@ class MatrixTable(ExprContainer): >>> print(entry_stats.global_gq_mean) """ - @staticmethod - def _from_java(jmir): - return MatrixTable(ir.JavaMatrix(jmir)) - @staticmethod @typecheck( globals=nullable(dictof(str, anytype)), diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index 8bd9b689af5..8fb1ba2ec3c 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -337,8 +337,8 @@ class Table(ExprContainer): """ @staticmethod - def _from_java(jtir): - return Table(ir.JavaTable(jtir)) + def _from_java(table_type, ir_id): + return Table(ir.JavaTable(table_type, ir_id)) def __init__(self, tir): super(Table, self).__init__() diff --git a/hail/python/test/hail/backend/test_service_backend.py b/hail/python/test/hail/backend/test_service_backend.py index 88e8b2ed31d..7050fb724e3 100644 --- a/hail/python/test/hail/backend/test_service_backend.py +++ b/hail/python/test/hail/backend/test_service_backend.py @@ -23,14 +23,14 @@ def test_tiny_driver_has_tiny_memory(): def test_big_driver_has_big_memory(): backend = hl.current_backend() assert isinstance(backend, ServiceBackend) + # A fresh backend is used for every test so this should only affect this method + backend.driver_cores = 8 + backend.driver_memory = 'highmem' t = hl.utils.range_table(100_000_000, 50) # The pytest (client-side) worker dies if we try to realize all 100M rows in memory. # Instead, we realize the 100M rows in memory on the driver and then take just the first 10M # rows back to the client. - hl.eval( - t.aggregate(hl.agg.collect(t.idx), _localize=False)[:10_000_000], - _execute_kwargs={'driver_cores': 8, 'driver_memory': 'highmem'} - ) + hl.eval(t.aggregate(hl.agg.collect(t.idx), _localize=False)[:10_000_000]) @qobtest @@ -52,14 +52,13 @@ def test_tiny_worker_has_tiny_memory(): def test_big_worker_has_big_memory(): backend = hl.current_backend() assert isinstance(backend, ServiceBackend) + backend.worker_cores = 8 + backend.worker_memory = 'highmem' t = hl.utils.range_table(2, n_partitions=2).annotate(nd=hl.nd.ones((30_000, 30_000))) t = t.annotate(nd_sum=t.nd.sum()) # We only eval the small thing so that we trigger an OOM on the worker # but not the driver or client - hl.eval( - t.aggregate(hl.agg.sum(t.nd_sum), _localize=False), - _execute_kwargs={'worker_cores': 8, 'worker_memory': 'highmem'} - ) + hl.eval(t.aggregate(hl.agg.sum(t.nd_sum), _localize=False)) @qobtest diff --git a/hail/python/test/hail/genetics/test_reference_genome.py b/hail/python/test/hail/genetics/test_reference_genome.py index 53dae55b167..30a15b88b9d 100644 --- a/hail/python/test/hail/genetics/test_reference_genome.py +++ b/hail/python/test/hail/genetics/test_reference_genome.py @@ -128,6 +128,7 @@ def test_reference_genome_liftover(): grch38.remove_liftover("GRCh37") +@qobtest def test_liftover_strand(): grch37 = hl.get_reference('GRCh37') grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38') @@ -184,6 +185,7 @@ def assert_rg_loaded_correctly(name): assert_rg_loaded_correctly('test_rg_2') +@qobtest def test_custom_reference_read_write(): hl.ReferenceGenome("dk", ['hello'], {"hello": 123}) ht = hl.utils.range_table(5) diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala index 3abf5866284..4e4063378b8 100644 --- a/hail/src/main/scala/is/hail/HailContext.scala +++ b/hail/src/main/scala/is/hail/HailContext.scala @@ -123,9 +123,6 @@ object HailContext { info(s"Running Hail version ${ theContext.version }") - // needs to be after `theContext` is set, since this creates broadcasts - backend.addDefaultReferences() - theContext } diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala index 255d9c3d4da..225baf4e274 100644 --- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala +++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala @@ -47,7 +47,7 @@ object HailFeatureFlags { ) ) - def fromMap(m: mutable.Map[String, String]): HailFeatureFlags = + def fromMap(m: Map[String, String]): HailFeatureFlags = new HailFeatureFlags( mutable.Map( HailFeatureFlags.defaults.map { diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index 25841e6f8f6..0e8f6e2be9e 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -1,16 +1,32 @@ package is.hail.backend +import java.io._ +import java.nio.charset.StandardCharsets + +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} + import is.hail.asm4s._ import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.expr.ir.{CodeCacheKey, CompiledFunction, LoweringAnalyses, SortField, TableIR, TableReader} +import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ +import is.hail.io.plink.LoadPlink +import is.hail.io.vcf.LoadVCF +import is.hail.expr.ir.{IRParser, BaseIR} import is.hail.linalg.BlockMatrix import is.hail.types._ +import is.hail.types.encoded.EType +import is.hail.types.virtual.TFloat64 +import is.hail.types.physical.PTuple import is.hail.utils._ import is.hail.variant.ReferenceGenome +import scala.collection.mutable import scala.reflect.ClassTag +import is.hail.expr.ir.IRParserEnvironment + object Backend { @@ -19,6 +35,12 @@ object Backend { id += 1 s"hail_query_$id" } + + private var irID: Int = 0 + def nextIRID(): Int = { + irID += 1 + irID + } } abstract class BroadcastValue[T] { def value: T } @@ -28,6 +50,16 @@ trait BackendContext { } abstract class Backend { + val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() + + protected[this] def addJavaIR(ir: BaseIR): Int = { + val id = Backend.nextIRID() + persistedIR += (id -> ir) + id + } + + def removeJavaIR(id: Int): Unit = persistedIR.remove(id) + def defaultParallelism: Int def canExecuteParallelTasksOnDriver: Boolean = true @@ -115,6 +147,91 @@ abstract class Backend { inputIR: TableIR, analyses: LoweringAnalyses ): TableStage + + def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T + + final def valueType(s: String): Array[Byte] = { + withExecuteContext("tableType") { ctx => + val v = IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + v.typ.toString.getBytes(StandardCharsets.UTF_8) + } + } + + private[this] def jsonToBytes(f: => JValue): Array[Byte] = { + JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) + } + + final def tableType(s: String): Array[Byte] = jsonToBytes { + withExecuteContext("tableType") { ctx => + val x = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + x.typ.toJSON + } + } + + final def matrixTableType(s: String): Array[Byte] = jsonToBytes { + withExecuteContext("matrixTableType") { ctx => + IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)).typ.pyJson + } + } + + final def blockMatrixType(s: String): Array[Byte] = jsonToBytes { + withExecuteContext("blockMatrixType") { ctx => + val x = IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val t = x.typ + JObject( + "element_type" -> JString(t.elementType.toString), + "shape" -> JArray(t.shape.map(s => JInt(s)).toList), + "is_row_vector" -> JBool(t.isRowVector), + "block_size" -> JInt(t.blockSize) + ) + } + } + + def loadReferencesFromDataset(path: String): Array[Byte] = { + withExecuteContext("loadReferencesFromDataset") { ctx => + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + rgs.foreach(addReference) + + implicit val formats: Formats = defaultJSONFormats + Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) + } + } + + def fromFASTAFile(name: String, fastaFile: String, indexFile: String, + xContigs: Array[String], yContigs: Array[String], mtContigs: Array[String], + parInput: Array[String]): Array[Byte] = { + withExecuteContext("fromFASTAFile") { ctx => + val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, + xContigs, yContigs, mtContigs, parInput) + rg.toJSONString.getBytes(StandardCharsets.UTF_8) + } + } + + def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes { + withExecuteContext("parseVCFMetadata") { ctx => + val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + implicit val formats = defaultJSONFormats + Extraction.decompose(metadata) + } + } + + def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String): Array[Byte] = { + withExecuteContext("importFam") { ctx => + LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes(StandardCharsets.UTF_8) + } + } + + def execute(ir: String, timed: Boolean)(consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit): Unit = () + + def encodeToOutputStream(ctx: ExecuteContext, t: PTuple, off: Long, bufferSpecString: String, os: OutputStream): Unit = { + val bs = BufferSpec.parseOrDefault(bufferSpecString) + assert(t.size == 1) + val elementType = t.fields(0).typ + val codec = TypedCodecSpec( + EType.fromTypeAllOptional(elementType.virtualType), elementType.virtualType, bs) + assert(t.isFieldDefined(off, 0)) + codec.encode(ctx, elementType, t.loadField(off, 0), os) + } } trait BackendWithCodeCache { diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala new file mode 100644 index 00000000000..d730ca67611 --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -0,0 +1,95 @@ +package is.hail.backend + +import java.net.InetSocketAddress +import java.nio.charset.StandardCharsets +import com.sun.net.httpserver.{HttpContext, HttpExchange, HttpHandler, HttpServer} + +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} + +import is.hail.utils._ + +case class IRTypePayload(ir: String) +case class LoadReferencesFromDatasetPayload(path: String) +case class FromFASTAFilePayload(name: String, fasta_file: String, index_file: String, + x_contigs: Array[String], y_contigs: Array[String], mt_contigs: Array[String], + par: Array[String]) +case class ParseVCFMetadataPayload(path: String) +case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) +case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean) + +object BackendServer { + def apply(backend: Backend) = new BackendServer(backend) +} + +class BackendServer(backend: Backend) { + // 0 => let the OS pick an available port + private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) + private[this] val handler = new BackendHttpHandler(backend) + + def port = httpServer.getAddress.getPort + + def start(): Unit = { + httpServer.createContext("/", handler) + httpServer.setExecutor(null) + httpServer.start() + } + + def stop(): Unit = { + httpServer.stop(10) + } +} + +class BackendHttpHandler(backend: Backend) extends HttpHandler { + def handle(exchange: HttpExchange): Unit = { + implicit val formats: Formats = DefaultFormats + + try { + val body = using(exchange.getRequestBody)(JsonMethods.parse(_)) + if (exchange.getRequestURI.getPath == "/execute") { + val config = body.extract[ExecutePayload] + backend.execute(config.ir, config.timed) { (ctx, res, timings) => + exchange.getResponseHeaders().add("X-Hail-Timings", timings) + res match { + case Left(_) => exchange.sendResponseHeaders(200, -1L) + case Right((t, off)) => + exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body + using(exchange.getResponseBody()) { os => + backend.encodeToOutputStream(ctx, t, off, config.stream_codec, os) + } + } + } + return + } + val response: Array[Byte] = exchange.getRequestURI.getPath match { + case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) + case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) + case "/matrixtable/type" => backend.matrixTableType(body.extract[IRTypePayload].ir) + case "/blockmatrix/type" => backend.blockMatrixType(body.extract[IRTypePayload].ir) + case "/references/load" => backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) + case "/references/from_fasta" => + val config = body.extract[FromFASTAFilePayload] + backend.fromFASTAFile(config.name, config.fasta_file, config.index_file, + config.x_contigs, config.y_contigs, config.mt_contigs, config.par) + case "/vcf/metadata/parse" => backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) + case "/fam/import" => + val config = body.extract[ImportFamPayload] + backend.importFam(config.path, config.quant_pheno, config.delimiter, config.missing) + } + + exchange.sendResponseHeaders(200, response.length) + using(exchange.getResponseBody())(_.write(response)) + } catch { + case t: Throwable => + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + val errorJson = JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId) + ) + val errorBytes = JsonMethods.compact(errorJson).getBytes(StandardCharsets.UTF_8) + exchange.sendResponseHeaders(500, errorBytes.length) + using(exchange.getResponseBody())(_.write(errorBytes)) + } + } +} diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 7952e862da0..22b727fadc8 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -25,6 +25,7 @@ import org.json4s.jackson.{JsonMethods, Serialization} import org.sparkproject.guava.util.concurrent.MoreExecutors import java.io.PrintWriter +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -55,6 +56,7 @@ object LocalBackend { gcsRequesterPaysProject, gcsRequesterPaysBuckets ) + theLocalBackend.addDefaultReferences() theLocalBackend } @@ -111,6 +113,15 @@ class LocalBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }) + def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T = { f => + ExecutionTimer.logTime(methodName) { timer => + ExecuteContext.scoped(tmpdir, tmpdir, this, fs, timer, null, theHailClassLoader, this.references, flags, new BackendContext { + override val executionCache: ExecutionCache = + ExecutionCache.fromFlags(flags, fs, tmpdir) + })(f) + } + } + def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value) private[this] var stageIdx: Int = 0 @@ -146,7 +157,7 @@ class LocalBackend( def stop(): Unit = LocalBackend.stop() - private[this] def _jvmLowerAndExecute(ctx: ExecuteContext, ir0: IR, print: Option[PrintWriter] = None): (Option[SingleCodeType], Long) = { + private[this] def _jvmLowerAndExecute(ctx: ExecuteContext, ir0: IR, print: Option[PrintWriter] = None): Either[Unit, (PTuple, Long)] = { val ir = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) @@ -162,11 +173,10 @@ class LocalBackend( } ctx.timer.time("Run") { - ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)) - (pt, 0) + Left(ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r))) } } else { - val (pt, f) = ctx.timer.time("Compile") { + val (Some(PTypeReferenceSingleCodeType(pt: PTuple)), f) = ctx.timer.time("Compile") { Compile[AsmFunction1RegionLong](ctx, FastSeq(), FastSeq(classInfo[Region]), LongInfo, @@ -175,12 +185,12 @@ class LocalBackend( } ctx.timer.time("Run") { - (pt, ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r))) + Right((pt, ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)))) } } } - private[this] def _execute(ctx: ExecuteContext, ir: IR): (Option[SingleCodeType], Long) = { + private[this] def _execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] = { TypeCheck(ctx, ir) Validate(ir) val queryID = Backend.nextID() @@ -192,41 +202,61 @@ class LocalBackend( } - def executeToJavaValue(timer: ExecutionTimer, ir: IR): Any = + def executeToJavaValue(timer: ExecutionTimer, ir: IR): (Any, ExecutionTimer) = withExecuteContext(timer) { ctx => - val (pt, a) = _execute(ctx, ir) - val result = pt match { - case None => + val result = _execute(ctx, ir) match { + case Left(_) => (null, ctx.timer) - case Some(PTypeReferenceSingleCodeType(pt: PTuple)) => - (SafeRow(pt, a).get(0), ctx.timer) + case Right((pt, off)) => + (SafeRow(pt, off).get(0), ctx.timer) } result } def executeToEncoded(timer: ExecutionTimer, ir: IR, bs: BufferSpec): Array[Byte] = withExecuteContext(timer) { ctx => - val (pt, a) = _execute(ctx, ir) - val result = pt match { - case None => - Array[Byte]() - case Some(PTypeReferenceSingleCodeType(pt: PTuple)) => + val result = _execute(ctx, ir) match { + case Left(_) => Array[Byte]() + case Right((pt, off)) => val elementType = pt.fields(0).typ - assert(pt.isFieldDefined(a, 0)) + assert(pt.isFieldDefined(off, 0)) val codec = TypedCodecSpec( EType.fromTypeAllOptional(elementType.virtualType), elementType.virtualType, bs) - codec.encode(ctx, elementType, pt.loadField(a, 0)) + codec.encode(ctx, elementType, pt.loadField(off, 0)) } result } + def executeLiteral(irStr: String): Int = { + ExecutionTimer.logTime("SparkBackend.executeLiteral") { timer => + withExecuteContext(timer) { ctx => + val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val t = ir.typ + assert(t.isRealizable) + val queryID = Backend.nextID() + log.info(s"starting execution of query $queryID} of initial size ${ IRSize(ir) }") + val retVal = _execute(ctx, ir) + val literalIR = retVal match { + case Left(x) => throw new HailException("Can't create literal") + case Right((pt, addr)) => GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + } + log.info(s"finished execution of query $queryID") + addJavaIR(literalIR) + } + } + } - def executeLiteral(ir: IR): IR = { - ExecutionTimer.logTime("LocalBackend.executeLiteral") { timer => - val t = ir.typ - assert(t.isRealizable) - val (value, timings) = executeToJavaValue(timer, ir) - Literal.coerce(t, value) + override def execute(ir: String, timed: Boolean)(consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit): Unit = { + withExecuteContext("LocalBackend.execute") { ctx => + val res = ctx.timer.time("execute") { + val irData = IRParser.parse_value_ir(ir, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val queryID = Backend.nextID() + log.info(s"starting execution of query $queryID of initial size ${ IRSize(irData) }") + _execute(ctx, irData) + } + ctx.timer.finish() + val timings = if (timed) Serialization.write(Map("timings" -> ctx.timer.toMap))(new DefaultFormats {}) else "" + consume(ctx, res, timings) } } @@ -287,36 +317,34 @@ class LocalBackend( } def pyRemoveSequence(name: String) = references(name).removeSequence() - def parse_value_ir(s: String, refMap: java.util.Map[String, String], irMap: java.util.Map[String, BaseIR]): IR = { + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = { ExecutionTimer.logTime("LocalBackend.parse_value_ir") { timer => withExecuteContext(timer) { ctx => - IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), irMap.asScala.toMap)) + IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), persistedIR.toMap)) } } } - def parse_table_ir(s: String, irMap: java.util.Map[String, BaseIR]): TableIR = { + def parse_table_ir(s: String): TableIR = { ExecutionTimer.logTime("LocalBackend.parse_table_ir") { timer => withExecuteContext(timer) { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = irMap.asScala.toMap)) + IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } } - def parse_matrix_ir(s: String, irMap: java.util.Map[String, BaseIR]): MatrixIR = { + def parse_matrix_ir(s: String): MatrixIR = { ExecutionTimer.logTime("LocalBackend.parse_matrix_ir") { timer => withExecuteContext(timer) { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = irMap.asScala.toMap)) + IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } } - def parse_blockmatrix_ir( - s: String, irMap: java.util.Map[String, BaseIR] - ): BlockMatrixIR = { + def parse_blockmatrix_ir(s: String): BlockMatrixIR = { ExecutionTimer.logTime("LocalBackend.parse_blockmatrix_ir") { timer => withExecuteContext(timer) { ctx => - IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = irMap.asScala.toMap)) + IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } } @@ -330,17 +358,6 @@ class LocalBackend( ): TableReader = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt, nPartitions) - def pyLoadReferencesFromDataset(path: String): String = { - val rgs = ReferenceGenome.fromHailDataset(fs, path) - rgs.foreach(addReference) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq) - } - - def pyImportFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String): String = - LoadPlink.importFamJSON(fs, path, isQuantPheno, delimiter, missingValue) - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String): Unit = ??? def unpersist(backendContext: BackendContext, id: String): Unit = ??? diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala index 49440f945ed..910f1e930ad 100644 --- a/hail/src/main/scala/is/hail/backend/service/Main.scala +++ b/hail/src/main/scala/is/hail/backend/service/Main.scala @@ -12,7 +12,7 @@ object Main { def main(argv: Array[String]): Unit = { argv(3) match { case WORKER => Worker.main(argv) - case DRIVER => ServiceBackendSocketAPI2.main(argv) + case DRIVER => ServiceBackendAPI.main(argv) case kind => throw new RuntimeException(s"unknown kind: ${kind}") } } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index cb50da0fc4b..9a79685fc78 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -43,7 +43,7 @@ class ServiceBackendContext( val workerMemory: String, val storageRequirement: String, val regions: Array[String], - val cloudfuseConfig: Array[(String, String, Boolean)], + val cloudfuseConfig: Array[CloudfuseConfig], val profile: Boolean, val executionCache: ExecutionCache, ) extends BackendContext with Serializable { @@ -51,6 +51,61 @@ class ServiceBackendContext( object ServiceBackend { private val log = Logger.getLogger(getClass.getName()) + + def apply( + jarLocation: String, + name: String, + theHailClassLoader: HailClassLoader, + batchClient: BatchClient, + batchId: Option[Long], + scratchDir: String = sys.env.get("HAIL_WORKER_SCRATCH_DIR").getOrElse(""), + rpcConfig: ServiceBackendRPCPayload + ): ServiceBackend = { + + val flags = HailFeatureFlags.fromMap(rpcConfig.flags) + val shouldProfile = flags.get("profile") != null + val fs = FS.cloudSpecificFS(s"${scratchDir}/secrets/gsa-key/key.json", Some(flags)) + + val backendContext = new ServiceBackendContext( + rpcConfig.billing_project, + rpcConfig.remote_tmpdir, + rpcConfig.worker_cores, + rpcConfig.worker_memory, + rpcConfig.storage, + rpcConfig.regions, + rpcConfig.cloudfuse_configs, + shouldProfile, + ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir) + ) + + val backend = new ServiceBackend( + jarLocation, + name, + new HailClassLoader(getClass().getClassLoader()), + batchClient, + batchId, + flags, + rpcConfig.tmp_dir, + fs, + backendContext, + scratchDir + ) + backend.addDefaultReferences() + + rpcConfig.custom_references.foreach { s => + backend.addReference(ReferenceGenome.fromJSON(s)) + } + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + backend.addLiftover(sourceGenome, chainFile, destGenome) + } + } + rpcConfig.sequences.foreach { case (rg, seq) => + backend.addSequence(rg, seq.fasta, seq.index) + } + + backend + } } class ServiceBackend( @@ -59,6 +114,10 @@ class ServiceBackend( val theHailClassLoader: HailClassLoader, val batchClient: BatchClient, val curBatchId: Option[Long], + val flags: HailFeatureFlags, + val tmpdir: String, + val fs: FS, + val serviceBackendContext: ServiceBackendContext, val scratchDir: String = sys.env.get("HAIL_WORKER_SCRATCH_DIR").getOrElse(""), ) extends Backend with BackendWithNoCodeCache { import ServiceBackend.log @@ -176,11 +235,11 @@ class ServiceBackend( "mount_tokens" -> JBool(true), "resources" -> resources, "regions" -> JArray(backendContext.regions.map(JString).toList), - "cloudfuse" -> JArray(backendContext.cloudfuseConfig.map { case (bucket, mountPoint, readonly) => + "cloudfuse" -> JArray(backendContext.cloudfuseConfig.map { config => JObject( - "bucket" -> JString(bucket), - "mount_path" -> JString(mountPoint), - "read_only" -> JBool(readonly) + "bucket" -> JString(config.bucket), + "mount_path" -> JString(config.mount_path), + "read_only" -> JBool(config.read_only) ) }.toList) ) @@ -249,47 +308,6 @@ class ServiceBackend( def stop(): Unit = executor.shutdownNow() - def valueType( - ctx: ExecuteContext, - s: String - ): String = { - val x = IRParser.parse_value_ir(ctx, s) - x.typ.toString - } - - def tableType( - ctx: ExecuteContext, - s: String - ): String = { - val x = IRParser.parse_table_ir(ctx, s) - val t = x.typ - val jv = JObject("global_type" -> JString(t.globalType.toString), - "row_type" -> JString(t.rowType.toString), - "row_key" -> JArray(t.key.map(f => JString(f)).toList)) - JsonMethods.compact(jv) - } - - def matrixTableType( - ctx: ExecuteContext, - s: String - ): String = { - val x = IRParser.parse_matrix_ir(ctx, s) - JsonMethods.compact(x.typ.pyJson) - } - - def blockMatrixType( - ctx: ExecuteContext, - s: String - ): String = { - val x = IRParser.parse_blockmatrix_ir(ctx, s) - val t = x.typ - val jv = JObject("element_type" -> JString(t.elementType.toString), - "shape" -> JArray(t.shape.map(s => JInt(s)).toList), - "is_row_vector" -> JBool(t.isRowVector), - "block_size" -> JInt(t.blockSize)) - JsonMethods.compact(jv) - } - private[this] def execute(ctx: ExecuteContext, _x: IR, bufferSpecString: String): Array[Byte] = { TypeCheck(ctx, _x) Validate(_x) @@ -305,19 +323,21 @@ class ServiceBackend( ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)) Array() } else { - val (Some(PTypeReferenceSingleCodeType(pt)), f) = Compile[AsmFunction1RegionLong](ctx, + val (Some(PTypeReferenceSingleCodeType(pt: PTuple)), f) = Compile[AsmFunction1RegionLong](ctx, FastSeq(), - FastSeq[TypeInfo[_]](classInfo[Region]), LongInfo, + FastSeq(classInfo[Region]), LongInfo, MakeTuple.ordered(FastSeq(x)), optimize = true) val retPType = pt.asInstanceOf[PBaseStruct] + val elementType = pt.fields(0).typ val off = ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)) val codec = TypedCodecSpec( - EType.fromTypeAllOptional(retPType.virtualType), - retPType.virtualType, + EType.fromTypeAllOptional(elementType.virtualType), + elementType.virtualType, BufferSpec.parseOrDefault(bufferSpecString) ) - codec.encode(ctx, retPType, off) + assert(pt.isFieldDefined(off, 0)) + codec.encode(ctx, elementType, pt.loadField(off, 0)) } } @@ -349,36 +369,6 @@ class ServiceBackend( def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ??? - def loadReferencesFromDataset( - ctx: ExecuteContext, - path: String - ): String = { - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - rgs.foreach(addReference) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq) - } - - def parseVCFMetadata( - ctx: ExecuteContext, - path: String - ): String = { - val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - implicit val formats = defaultJSONFormats - JsonMethods.compact(Extraction.decompose(metadata)) - } - - def importFam( - ctx: ExecuteContext, - path: String, - quantPheno: Boolean, - delimiter: String, - missing: String - ): String = { - LoadPlink.importFamJSON(ctx.fs, path, quantPheno, delimiter, missing) - } - def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses @@ -386,25 +376,40 @@ class ServiceBackend( LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) } - def fromFASTAFile( - ctx: ExecuteContext, - name: String, - fastaFile: String, - indexFile: String, - xContigs: Array[String], - yContigs: Array[String], - mtContigs: Array[String], - parInput: Array[String] - ): String = { - val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs, yContigs, mtContigs, parInput) - rg.toJSONString + def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T = { f => + ExecutionTimer.logTime(methodName) { timer => + ExecuteContext.scoped( + tmpdir, + "file:///tmp", + this, + fs, + timer, + null, + theHailClassLoader, + references, + flags, + serviceBackendContext + )(f) + } + } + + def addLiftover(name: String, chainFile: String, destRGName: String): Unit = { + withExecuteContext("addLiftover") { ctx => + references(name).addLiftover(ctx, chainFile, destRGName) + } + } + + def addSequence(name: String, fastaFile: String, indexFile: String): Unit = { + withExecuteContext("addSequence") { ctx => + references(name).addSequence(ctx, fastaFile, indexFile) + } } } class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) -object ServiceBackendSocketAPI2 { +object ServiceBackendAPI { private[this] val log = Logger.getLogger(getClass.getName()) def main(argv: Array[String]): Unit = { @@ -431,82 +436,27 @@ object ServiceBackendSocketAPI2 { var batchId = BatchConfig.fromConfigFile(s"$scratchDir/batch-config/batch-config.json").map(_.batchId) log.info("BatchConfig parsed.") + implicit val formats: Formats = DefaultFormats + val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) + val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] + // FIXME: when can the classloader be shared? (optimizer benefits!) - val backend = new ServiceBackend( - jarLocation, name, new HailClassLoader(getClass().getClassLoader()), batchClient, batchId, scratchDir) + val backend = ServiceBackend( + jarLocation, name, new HailClassLoader(getClass().getClassLoader()), batchClient, batchId, scratchDir, + rpcConfig + ) log.info("ServiceBackend allocated.") if (HailContext.isInitialized) { HailContext.get.backend = backend - backend.addDefaultReferences() log.info("Default references added to already initialized HailContexet.") } else { HailContext(backend, 50, 3) log.info("HailContexet initialized.") } - new ServiceBackendSocketAPI2(backend, fs, inputURL, outputURL).executeOneCommand() - } -} - -private class HailSocketAPIInputStream( - private[this] val in: InputStream -) extends AutoCloseable { - private[this] var closed: Boolean = false - private[this] val dummy = new Array[Byte](8) - - def read(bytes: Array[Byte], off: Int, n: Int): Unit = { - assert(off + n <= bytes.length) - var read = 0 - while (read < n) { - val r = in.read(bytes, off + read, n - read) - if (r < 0) { - throw new EndOfInputException - } else { - read += r - } - } - } - - def readBool(): Boolean = { - read(dummy, 0, 1) - Memory.loadByte(dummy, 0) != 0.toByte - } - - def readInt(): Int = { - read(dummy, 0, 4) - Memory.loadInt(dummy, 0) - } - - def readLong(): Long = { - read(dummy, 0, 8) - Memory.loadLong(dummy, 0) - } - - def readBytes(): Array[Byte] = { - val n = readInt() - val bytes = new Array[Byte](n) - read(bytes, 0, n) - bytes - } - - def readString(): String = new String(readBytes(), StandardCharsets.UTF_8) - - def readStringArray(): Array[String] = { - val n = readInt() - val arr = new Array[String](n) - var i = 0 - while (i < n) { - arr(i) = readString() - i += 1 - } - arr - } - - def close(): Unit = { - if (!closed) { - in.close() - closed = true - } + val action = (input \ "action").extract[Int] + val payload = (input \ "payload") + new ServiceBackendAPI(backend, fs, outputURL).executeOneCommand(action, payload) } } @@ -545,10 +495,43 @@ private class HailSocketAPIOutputStream( } } -class ServiceBackendSocketAPI2( +case class CloudfuseConfig(bucket: String, mount_path: String, read_only: Boolean) + +case class SequenceConfig(fasta: String, index: String) + +case class ServiceBackendRPCPayload( + tmp_dir: String, + remote_tmpdir: String, + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], + flags: Map[String, String], + custom_references: Array[String], + liftovers: Map[String, Map[String, String]], + sequences: Map[String, SequenceConfig], +) + +case class ServiceBackendExecutePayload( + functions: Array[SerializedIRFunction], + idempotency_token: String, + payload: ExecutePayload, +) + +case class SerializedIRFunction( + name: String, + type_parameters: Array[String], + value_parameter_names: Array[String], + value_parameter_types: Array[String], + return_type: String, + rendered_body: String, +) + +class ServiceBackendAPI( private[this] val backend: ServiceBackend, private[this] val fs: FS, - private[this] val inputURL: String, private[this] val outputURL: String, ) extends Thread { private[this] val LOAD_REFERENCES_FROM_DATASET = 1 @@ -563,251 +546,75 @@ class ServiceBackendSocketAPI2( private[this] val log = Logger.getLogger(getClass.getName()) - private[this] def parseInputToCommandThunk(): () => Array[Byte] = retryTransientErrors { - using(fs.openNoCompression(inputURL)) { inputStream => - val input = new HailSocketAPIInputStream(inputStream) - - var nFlagsRemaining = input.readInt() - val flagsMap = mutable.Map[String, String]() - while (nFlagsRemaining > 0) { - val flagName = input.readString() - val flagValue = input.readString() - flagsMap.update(flagName, flagValue) - nFlagsRemaining -= 1 - } - val nCustomReferences = input.readInt() - var i = 0 - while (i < nCustomReferences) { - backend.addReference(ReferenceGenome.fromJSON(input.readString())) - i += 1 - } - val nLiftoverSourceGenomes = input.readInt() - val liftovers = mutable.Map[String, mutable.Map[String, String]]() - i = 0 - while (i < nLiftoverSourceGenomes) { - val sourceGenome = input.readString() - val nLiftovers = input.readInt() - liftovers(sourceGenome) = mutable.Map[String, String]() - var j = 0 - while (j < nLiftovers) { - val destGenome = input.readString() - val chainFile = input.readString() - liftovers(sourceGenome)(destGenome) = chainFile - j += 1 - } - i += 1 - } - val nAddedSequences = input.readInt() - val addedSequences = mutable.Map[String, (String, String)]() - i = 0 - while (i < nAddedSequences) { - val rgName = input.readString() - val fastaFile = input.readString() - val indexFile = input.readString() - addedSequences(rgName) = (fastaFile, indexFile) - i += 1 - } - val workerCores = input.readString() - val workerMemory = input.readString() - - var nRegions = input.readInt() - val regions = { - val regionsArrayBuffer = mutable.ArrayBuffer[String]() - while (nRegions > 0) { - val region = input.readString() - regionsArrayBuffer += region - nRegions -= 1 - } - regionsArrayBuffer.toArray - } - - val storageRequirement = input.readString() - val nCloudfuseConfigElements = input.readInt() - val cloudfuseConfig = new Array[(String, String, Boolean)](nCloudfuseConfigElements) - i = 0 - while (i < nCloudfuseConfigElements) { - val bucket = input.readString() - val mountPoint = input.readString() - val readonly = input.readBool() - cloudfuseConfig(i) = (bucket, mountPoint, readonly) - i += 1 - } - - val cmd = input.readInt() - - val tmpdir = input.readString() - val billingProject = input.readString() - val remoteTmpDir = input.readString() - def withExecuteContext( - methodName: String, - method: ExecuteContext => Array[Byte] - ): () => Array[Byte] = { - val flags = HailFeatureFlags.fromMap(flagsMap) - val shouldProfile = flags.get("profile") != null - val fs = FS.cloudSpecificFS(s"${backend.scratchDir}/secrets/gsa-key/key.json", Some(flags)) - - { () => - ExecutionTimer.logTime(methodName) { timer => - ExecuteContext.scoped( - tmpdir, - "file:///tmp", - backend, - fs, - timer, - null, - backend.theHailClassLoader, - backend.references, - flags, - new ServiceBackendContext(billingProject, remoteTmpDir, workerCores, workerMemory, storageRequirement, regions, cloudfuseConfig, shouldProfile, - ExecutionCache.fromFlags(flags, fs, remoteTmpDir) - ) - ) { ctx => - liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome) - } - } - addedSequences.foreach { case (rg, (fastaFile, indexFile)) => - ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile) - } - method(ctx) - } + private[this] def doAction(action: Int, payload: JValue): Array[Byte] = retryTransientErrors { + implicit val formats: Formats = DefaultFormats + (action: @switch) match { + case LOAD_REFERENCES_FROM_DATASET => + val path = payload.extract[LoadReferencesFromDatasetPayload].path + backend.loadReferencesFromDataset(path) + case VALUE_TYPE => + val ir = payload.extract[IRTypePayload].ir + backend.valueType(ir) + case TABLE_TYPE => + val ir = payload.extract[IRTypePayload].ir + backend.tableType(ir) + case MATRIX_TABLE_TYPE => + val ir = payload.extract[IRTypePayload].ir + backend.matrixTableType(ir) + case BLOCK_MATRIX_TYPE => + val ir = payload.extract[IRTypePayload].ir + backend.blockMatrixType(ir) + case EXECUTE => + val qobExecutePayload = payload.extract[ServiceBackendExecutePayload] + val bufferSpecString = qobExecutePayload.payload.stream_codec + val code = qobExecutePayload.payload.ir + val token = qobExecutePayload.idempotency_token + backend.withExecuteContext("ServiceBackend.execute") { ctx => + withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => + backend.execute(ctx, code, token, bufferSpecString) } } - } - - (cmd: @switch) match { - case LOAD_REFERENCES_FROM_DATASET => - val path = input.readString() - withExecuteContext( - "ServiceBackend.loadReferencesFromDataset", - backend.loadReferencesFromDataset(_, path).getBytes(StandardCharsets.UTF_8) - ) - case VALUE_TYPE => - val s = input.readString() - withExecuteContext( - "ServiceBackend.valueType", - backend.valueType(_, s).getBytes(StandardCharsets.UTF_8) - ) - case TABLE_TYPE => - val s = input.readString() - withExecuteContext( - "ServiceBackend.tableType", - backend.tableType(_, s).getBytes(StandardCharsets.UTF_8) - ) - case MATRIX_TABLE_TYPE => - val s = input.readString() - withExecuteContext( - "ServiceBackend.matrixTableType", - backend.matrixTableType(_, s).getBytes(StandardCharsets.UTF_8) - ) - case BLOCK_MATRIX_TYPE => - val s = input.readString() - withExecuteContext( - "ServiceBackend.blockMatrixType", - backend.blockMatrixType(_, s).getBytes(StandardCharsets.UTF_8) - ) - case EXECUTE => - val code = input.readString() - val token = input.readString() - withExecuteContext( - "ServiceBackend.execute", - { ctx => - withIRFunctionsReadFromInput(input, ctx) { () => - val bufferSpecString = input.readString() - backend.execute(ctx, code, token, bufferSpecString) - } - } - ) - case PARSE_VCF_METADATA => - val path = input.readString() - withExecuteContext( - "ServiceBackend.parseVCFMetadata", - backend.parseVCFMetadata(_, path).getBytes(StandardCharsets.UTF_8) - ) - case IMPORT_FAM => - val path = input.readString() - val quantPheno = input.readBool() - val delimiter = input.readString() - val missing = input.readString() - withExecuteContext( - "ServiceBackend.importFam", - backend.importFam(_, path, quantPheno, delimiter, missing).getBytes(StandardCharsets.UTF_8) - ) - case FROM_FASTA_FILE => - val name = input.readString() - val fastaFile = input.readString() - val indexFile = input.readString() - val xContigs = input.readStringArray() - val yContigs = input.readStringArray() - val mtContigs = input.readStringArray() - val parInput = input.readStringArray() - withExecuteContext( - "ServiceBackend.fromFASTAFile", - backend.fromFASTAFile( - _, - name, - fastaFile, - indexFile, - xContigs, - yContigs, - mtContigs, - parInput - ).getBytes(StandardCharsets.UTF_8) - ) - } + case PARSE_VCF_METADATA => + val path = payload.extract[ParseVCFMetadataPayload].path + backend.parseVCFMetadata(path) + case IMPORT_FAM => + val famPayload = payload.extract[ImportFamPayload] + val path = famPayload.path + val quantPheno = famPayload.quant_pheno + val delimiter = famPayload.delimiter + val missing = famPayload.missing + backend.importFam(path, quantPheno, delimiter, missing) + case FROM_FASTA_FILE => + val fastaPayload = payload.extract[FromFASTAFilePayload] + backend.fromFASTAFile( + fastaPayload.name, + fastaPayload.fasta_file, + fastaPayload.index_file, + fastaPayload.x_contigs, + fastaPayload.y_contigs, + fastaPayload.mt_contigs, + fastaPayload.par + ) } } private[this] def withIRFunctionsReadFromInput( - input: HailSocketAPIInputStream, + serializedFunctions: Array[SerializedIRFunction], ctx: ExecuteContext )( body: () => Array[Byte] ): Array[Byte] = { try { - var nFunctionsRemaining = input.readInt() - while (nFunctionsRemaining > 0) { - val name = input.readString() - - val nTypeParametersRemaining = input.readInt() - val typeParameters = new Array[String](nTypeParametersRemaining) - var i = 0 - while (i < nTypeParametersRemaining) { - typeParameters(i) = input.readString() - i += 1 - } - - val nValueParameterNamesRemaining = input.readInt() - val valueParameterNames = new Array[String](nValueParameterNamesRemaining) - i = 0 - while (i < nValueParameterNamesRemaining) { - valueParameterNames(i) = input.readString() - i += 1 - } - - val nValueParameterTypesRemaining = input.readInt() - val valueParameterTypes = new Array[String](nValueParameterTypesRemaining) - i = 0 - while (i < nValueParameterTypesRemaining) { - valueParameterTypes(i) = input.readString() - i += 1 - } - - val returnType = input.readString() - - val renderedBody = input.readString() - + serializedFunctions.foreach { func => IRFunctionRegistry.pyRegisterIRForServiceBackend( ctx, - name, - typeParameters, - valueParameterNames, - valueParameterTypes, - returnType, - renderedBody + func.name, + func.type_parameters, + func.value_parameter_names, + func.value_parameter_types, + func.return_type, + func.rendered_body ) - nFunctionsRemaining -= 1 } body() } finally { @@ -815,11 +622,9 @@ class ServiceBackendSocketAPI2( } } - def executeOneCommand(): Unit = { - val commandThunk = parseInputToCommandThunk() - + def executeOneCommand(action: Int, payload: JValue): Unit = { try { - val result = commandThunk() + val result = doAction(action, payload) retryTransientErrors { using(fs.createNoCompression(outputURL)) { outputStream => val output = new HailSocketAPIOutputStream(outputStream) diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 5a63d3a4071..53619d5a956 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -159,11 +159,11 @@ object Worker { timer.start("executeFunction") if (HailContext.isInitialized) { - HailContext.get.backend = new ServiceBackend(null, null, new HailClassLoader(getClass().getClassLoader()), null, None) + HailContext.get.backend = new ServiceBackend(null, null, new HailClassLoader(getClass().getClassLoader()), null, None, null, null, null, null) } else { HailContext( // FIXME: workers should not have backends, but some things do need hail contexts - new ServiceBackend(null, null, new HailClassLoader(getClass().getClassLoader()), null, None)) + new ServiceBackend(null, null, new HailClassLoader(getClass().getClassLoader()), null, None, null, null, null, null)) } val result = using(new ServiceTaskContext(i)) { htc => diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index d7363b06ae8..5f03e000f9c 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -32,7 +32,8 @@ import org.json4s import org.json4s.jackson.{JsonMethods, Serialization} import org.json4s.{DefaultFormats, Formats} -import java.io.{Closeable, PrintWriter} +import com.sun.net.httpserver.{HttpExchange} +import java.io.{Closeable, PrintWriter, OutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -246,6 +247,7 @@ object SparkBackend { sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) theSparkBackend = new SparkBackend(tmpdir, localTmpdir, sc1, gcsRequesterPaysProject, gcsRequesterPaysBuckets) + theSparkBackend.addDefaultReferences() theSparkBackend } @@ -358,6 +360,16 @@ class SparkBackend( } ) + def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T = { f => + ExecutionTimer.logTime(methodName) { timer => + ExecuteContext.scoped(tmpdir, tmpdir, this, fs, timer, null, theHailClassLoader, this.references, flags, new BackendContext { + override val executionCache: ExecutionCache = + ExecutionCache.fromFlags(flags, fs, tmpdir) + })(f) + } + } + + def broadcast[T : ClassTag](value: T): BroadcastValue[T] = new SparkBroadcastValue[T](sc.broadcast(value)) override def parallelizeAndComputeWithIndex( @@ -506,11 +518,12 @@ class SparkBackend( } } - def executeLiteral(ir: IR): IR = { - val t = ir.typ - assert(t.isRealizable) + def executeLiteral(irStr: String): Int = { ExecutionTimer.logTime("SparkBackend.executeLiteral") { timer => withExecuteContext(timer) { ctx => + val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val t = ir.typ + assert(t.isRealizable) val queryID = Backend.nextID() log.info(s"starting execution of query $queryID} of initial size ${ IRSize(ir) }") val retVal = _execute(ctx, ir, true) @@ -519,18 +532,23 @@ class SparkBackend( case Right((pt, addr)) => GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) } log.info(s"finished execution of query $queryID") - literalIR + addJavaIR(literalIR) } } } - def executeJSON(ir: IR): String = { - val (jsonValue, timer) = ExecutionTimer.time("SparkBackend.executeJSON") { timer => - val t = ir.typ - val value = execute(timer, ir, optimize = true) - JsonMethods.compact(JSONAnnotationImpex.exportAnnotation(value, t)) + override def execute(ir: String, timed: Boolean)(consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit): Unit = { + withExecuteContext("SparkBackend.execute") { ctx => + val res = ctx.timer.time("execute") { + val irData = IRParser.parse_value_ir(ir, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val queryID = Backend.nextID() + log.info(s"starting execution of query $queryID of initial size ${ IRSize(irData) }") + _execute(ctx, irData, true) + } + ctx.timer.finish() + val timings = if (timed) Serialization.write(Map("timings" -> ctx.timer.toMap))(new DefaultFormats {}) else "" + consume(ctx, res, timings) } - Serialization.write(Map("value" -> jsonValue, "timings" -> timer.toMap))(new DefaultFormats {}) } def executeEncode(ir: IR, bufferSpecString: String, timed: Boolean): (Array[Byte], String) = { @@ -574,19 +592,22 @@ class SparkBackend( } } - def pyFromDF(df: DataFrame, jKey: java.util.List[String]): TableIR = { + def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { ExecutionTimer.logTime("SparkBackend.pyFromDF") { timer => val key = jKey.asScala.toArray.toFastSeq val signature = SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] withExecuteContext(timer, selfContainedExecution = false) { ctx => - TableLiteral(TableValue(ctx, signature.virtualType.asInstanceOf[TStruct], key, df.rdd, Some(signature)), ctx.theHailClassLoader) + val tir = TableLiteral(TableValue(ctx, signature.virtualType.asInstanceOf[TStruct], key, df.rdd, Some(signature)), ctx.theHailClassLoader) + val id = addJavaIR(tir) + (id, JsonMethods.compact(tir.typ.toJSON)) } } } - def pyToDF(tir: TableIR): DataFrame = { + def pyToDF(s: String): DataFrame = { ExecutionTimer.logTime("SparkBackend.pyToDF") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => + val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) Interpret(tir, ctx).toDF() } } @@ -614,14 +635,6 @@ class SparkBackend( matrixReaders.asJava } - def pyLoadReferencesFromDataset(path: String): String = { - val rgs = ReferenceGenome.fromHailDataset(fs, path) - rgs.foreach(addReference) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq) - } - def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) def pyRemoveReference(name: String): Unit = removeReference(name) @@ -685,36 +698,34 @@ class SparkBackend( } } - def parse_value_ir(s: String, refMap: java.util.Map[String, String], irMap: java.util.Map[String, BaseIR]): IR = { + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = { ExecutionTimer.logTime("SparkBackend.parse_value_ir") { timer => withExecuteContext(timer) { ctx => - IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), irMap.asScala.toMap)) + IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), irMap = persistedIR.toMap)) } } } - def parse_table_ir(s: String, irMap: java.util.Map[String, BaseIR]): TableIR = { + def parse_table_ir(s: String): TableIR = { ExecutionTimer.logTime("SparkBackend.parse_table_ir") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = irMap.asScala.toMap)) + IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } } - def parse_matrix_ir(s: String, irMap: java.util.Map[String, BaseIR]): MatrixIR = { + def parse_matrix_ir(s: String): MatrixIR = { ExecutionTimer.logTime("SparkBackend.parse_matrix_ir") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = irMap.asScala.toMap)) + IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } } - def parse_blockmatrix_ir( - s: String, irMap: java.util.Map[String, BaseIR] - ): BlockMatrixIR = { + def parse_blockmatrix_ir(s: String): BlockMatrixIR = { ExecutionTimer.logTime("SparkBackend.parse_blockmatrix_ir") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => - IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = irMap.asScala.toMap)) + IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } } @@ -754,9 +765,6 @@ class SparkBackend( RVDTableReader(RVD.unkeyed(rowPType, orderedCRDD), globalsLit, rt) } - def pyImportFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String): String = - LoadPlink.importFamJSON(fs, path, isQuantPheno, delimiter, missingValue) - def close(): Unit = { longLifeTempFileManager.cleanup() } diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala index 3983d152d01..a2c30a0e6de 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala @@ -127,7 +127,7 @@ object IRLexer extends JavaTokenParsers { case class IRParserEnvironment( ctx: ExecuteContext, refMap: BindingEnv[Type] = BindingEnv.empty[Type], - irMap: Map[String, BaseIR] = Map.empty, + irMap: Map[Int, BaseIR] = Map.empty, ) { def promoteAgg: IRParserEnvironment = copy(refMap = refMap.promoteAgg) @@ -1529,8 +1529,8 @@ object IRParser { dynamicID <- ir_value_expr(env)(it) } yield CollectDistributedArray(ctxs, globals, cname, gname, body, dynamicID, staticID) case "JavaIR" => - val name = identifier(it) - done(env.irMap(name).asInstanceOf[IR]) + val id = int32_literal(it) + done(env.irMap(id).asInstanceOf[IR]) case "ReadPartition" => val requestedTypeRaw = it.head match { case x: IdentifierToken if x.value == "None" || x.value == "DropRowUIDs" => @@ -1795,8 +1795,8 @@ object IRParser { body <- table_ir(env.onlyRelational.bindRelational(name, value.typ))(it) } yield RelationalLetTable(name, value, body) case "JavaTable" => - val name = identifier(it) - done(env.irMap(name).asInstanceOf[TableIR]) + val id = int32_literal(it) + done(env.irMap(id).asInstanceOf[TableIR]) } } @@ -2002,9 +2002,6 @@ object IRParser { value <- ir_value_expr(env.onlyRelational)(it) body <- matrix_ir(env.onlyRelational.bindRelational(name, value.typ))(it) } yield RelationalLetMatrixTable(name, value, body) - case "JavaMatrix" => - val name = identifier(it) - done(env.irMap(name).asInstanceOf[MatrixIR]) } } @@ -2141,9 +2138,6 @@ object IRParser { value <- ir_value_expr(env.onlyRelational)(it) body <- blockmatrix_ir(env.onlyRelational.bindRelational(name, value.typ))(it) } yield RelationalLetBlockMatrix(name, value, body) - case "JavaBlockMatrix" => - val name = identifier(it) - done(env.irMap(name).asInstanceOf[BlockMatrixIR]) } } diff --git a/hail/src/main/scala/is/hail/io/CodecSpec.scala b/hail/src/main/scala/is/hail/io/CodecSpec.scala index 65d14ded8db..671959eaf83 100644 --- a/hail/src/main/scala/is/hail/io/CodecSpec.scala +++ b/hail/src/main/scala/is/hail/io/CodecSpec.scala @@ -35,13 +35,17 @@ trait AbstractTypedCodecSpec extends Spec { def encode(ctx: ExecuteContext, t: PType, offset: Long): Array[Byte] = { val baos = new ByteArrayOutputStream() - using(buildEncoder(ctx, t)(baos, ctx.theHailClassLoader))(_.writeRegionValue(offset)) + encode(ctx, t, offset, baos) baos.toByteArray } + def encode(ctx: ExecuteContext, t: PType, offset: Long, os: OutputStream): Unit = { + using(buildEncoder(ctx, t)(os, ctx.theHailClassLoader))(_.writeRegionValue(offset)) + } + def encodeArrays(ctx: ExecuteContext, t: PType, offset: Long): Array[Array[Byte]] = { val baos = new ArrayOfByteArrayOutputStream() - using(buildEncoder(ctx, t)(baos, ctx.theHailClassLoader))(_.writeRegionValue(offset)) + encode(ctx, t, offset, baos) baos.toByteArrays() } diff --git a/hail/src/main/scala/is/hail/types/TableType.scala b/hail/src/main/scala/is/hail/types/TableType.scala index c1a370c9030..2defa934819 100644 --- a/hail/src/main/scala/is/hail/types/TableType.scala +++ b/hail/src/main/scala/is/hail/types/TableType.scala @@ -5,6 +5,8 @@ import is.hail.types.physical.{PStruct, PType} import is.hail.types.virtual.{TStruct, Type} import is.hail.rvd.RVDType import is.hail.utils._ + +import org.json4s._ import org.json4s.CustomSerializer import org.json4s.JsonAST.JString @@ -78,4 +80,12 @@ case class TableType(rowType: TStruct, key: IndexedSeq[String], globalType: TStr newline() sb += '}' } + + def toJSON: JObject = { + JObject( + "global_type" -> JString(globalType.toString), + "row_type" -> JString(rowType.toString), + "row_key" -> JArray(key.map(f => JString(f)).toList) + ) + } } diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index a057b1e4b74..852e95f01ed 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -3101,27 +3101,18 @@ class IRSuite extends HailSuite { @Test def testCachedIR() { val cached = Literal(TSet(TInt32), Set(1)) - val s = s"(JavaIR __uid1)" + val s = s"(JavaIR 1)" val x2 = ExecuteContext.scoped() { ctx => - IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = Map("__uid1" -> cached))) + IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = Map(1 -> cached))) } assert(x2 eq cached) } @Test def testCachedTableIR() { val cached = TableRange(1, 1) - val s = s"(JavaTable __uid1)" + val s = s"(JavaTable 1)" val x2 = ExecuteContext.scoped() { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = Map("__uid1" -> cached))) - } - assert(x2 eq cached) - } - - @Test def testCachedMatrixIR() { - val cached = MatrixIR.range(3, 7, None) - val s = s"(JavaMatrix __uid1)" - val x2 = ExecuteContext.scoped() { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = Map("__uid1" -> cached))) + IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = Map(1 -> cached))) } assert(x2 eq cached) }