[query] Avoid py4j for python-backend interactions (#13797)
CHANGELOG: Fixes #13756: operations that collect large results, such as
`to_pandas`, may require up to 3x less memory.

This turns all "actions", i.e. backend methods supported by QoB, into
HTTP endpoints on the Spark and local backends. It intentionally avoids
py4j: py4j was designed to pass function names and object references
around, and it does not handle large payloads well (such as results
from a `collect`). Specifically, py4j uses a text-based protocol on top
of TCP that substantially inflates the memory required to communicate
large byte arrays. On the Java side, py4j serializes every binary
payload as a Base64-encoded `java.lang.String`; between the Base64
encoding (a 4/3 inflation) and `String`'s use of UTF-16 (a further 2x),
the `String`'s memory footprint is `4/3 * 2 = 8/3` times that of the
byte array, nearly 3x, on either side of the py4j pipe. py4j also
appears to make a full copy of this payload, bringing the memory
requirement for sending back bytes to nearly 6x. Using our own socket
means we can send the response bytes directly to Python without any of
this overhead, even encoding results straight into the TCP output
stream. Formalizing the API between Python and Java also lets us reuse
the same payload schema across all three backends.
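
To make the arithmetic concrete, here is a back-of-the-envelope
calculation of the footprint of returning an n-byte result through
py4j. This is an illustration only, not code from this commit; the
constants are the ones cited in the message above:

def py4j_string_footprint(n_bytes: int) -> float:
    """Approximate size of the Base64 java.lang.String holding an n-byte payload."""
    base64_chars = n_bytes * 4 / 3   # Base64 emits 4 characters per 3 input bytes
    return base64_chars * 2          # java.lang.String is UTF-16: 2 bytes per character

n = 1 << 30                                  # e.g. a 1 GiB collect result
string_bytes = py4j_string_footprint(n)      # ~2.67 GiB: the 8/3 factor
peak = string_bytes * 2                      # ~5.33 GiB with py4j's extra copy: "nearly 6x"
print(f'string: {string_bytes / n:.2f}x, peak: {peak / n:.2f}x')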
daniel-goldstein authored Oct 20, 2023
1 parent 209404e commit c73386f
Showing 31 changed files with 1,054 additions and 1,352 deletions.
hail/python/hail/backend/backend.py: 98 additions & 20 deletions
@@ -1,5 +1,7 @@
 from typing import Mapping, List, Union, TypeVar, Tuple, Dict, Optional, Any, AbstractSet
 import abc
+from enum import Enum
+from dataclasses import dataclass
 import warnings
 import orjson
 import pkg_resources
@@ -10,8 +12,12 @@
 
 from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS
 from ..expr import Expression
-from ..expr.types import HailType
-from ..ir import BaseIR
+from ..expr.table_type import ttable
+from ..expr.matrix_type import tmatrix
+from ..expr.blockmatrix_type import tblockmatrix
+from ..expr.types import HailType, dtype
+from ..ir import BaseIR, finalize_randomness
+from ..ir.renderer import CSERenderer
 from ..linalg.blockmatrix import BlockMatrix
 from ..matrixtable import MatrixTable
 from ..table import Table
@@ -70,6 +76,57 @@ def local_jar_information() -> LocalJarInformation:
     )
 
 
+class ActionTag(Enum):
+    LOAD_REFERENCES_FROM_DATASET = 1
+    VALUE_TYPE = 2
+    TABLE_TYPE = 3
+    MATRIX_TABLE_TYPE = 4
+    BLOCK_MATRIX_TYPE = 5
+    EXECUTE = 6
+    PARSE_VCF_METADATA = 7
+    IMPORT_FAM = 8
+    FROM_FASTA_FILE = 9
+
+@dataclass
+class ActionPayload:
+    pass
+
+@dataclass
+class IRTypePayload(ActionPayload):
+    ir: str
+
+@dataclass
+class ExecutePayload(ActionPayload):
+    ir: str
+    stream_codec: str
+    timed: bool
+
+@dataclass
+class LoadReferencesFromDatasetPayload(ActionPayload):
+    path: str
+
+@dataclass
+class ParseVCFMetadataPayload(ActionPayload):
+    path: str
+
+@dataclass
+class ImportFamPayload(ActionPayload):
+    path: str
+    quant_pheno: bool
+    delimiter: str
+    missing: str
+
+@dataclass
+class FromFASTAFilePayload(ActionPayload):
+    name: str
+    fasta_file: str
+    index_file: str
+    x_contigs: List[str]
+    y_contigs: List[str]
+    mt_contigs: List[str]
+    par: List[str]
+
+
 class Backend(abc.ABC):
     # Must match knownFlags in HailFeatureFlags.scala
     _flags_env_vars_and_defaults: Dict[str, Tuple[str, Optional[str]]] = {
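
The payload dataclasses above pin down the request schema that all
three backends now share. This file does not show how a payload is put
on the wire; serializing one for an HTTP action endpoint could look
like the following sketch (illustrative only; the commit's actual
encoding lives in the backend-specific code):

import orjson
from dataclasses import asdict

# Build the same payload execute() constructs, then encode it as a JSON body.
payload = ExecutePayload(ir='(Literal ...)', stream_codec='{"name":"StreamBufferSpec"}', timed=False)
body = orjson.dumps(asdict(payload))
# b'{"ir":"(Literal ...)","stream_codec":"...","timed":false}'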
@@ -115,33 +172,52 @@ def validate_file(self, uri: str):
     def stop(self):
         pass
 
-    @abc.abstractmethod
     def execute(self, ir: BaseIR, timed: bool = False) -> Any:
-        pass
+        payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed)
+        try:
+            result, timings = self._rpc(ActionTag.EXECUTE, payload)
+        except FatalError as e:
+            raise e.maybe_user_error(ir) from None
+        value = ir.typ._from_encoding(result)
+        return (value, timings) if timed else value
 
     @abc.abstractmethod
-    async def _async_execute(self, ir, timed=False):
+    def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]:
         pass
 
-    @abc.abstractmethod
+    def _render_ir(self, ir):
+        r = CSERenderer()
+        return r(finalize_randomness(ir))
+
     def value_type(self, ir):
-        pass
+        payload = IRTypePayload(ir=self._render_ir(ir))
+        type_bytes, _ = self._rpc(ActionTag.VALUE_TYPE, payload)
+        return dtype(type_bytes.decode('utf-8'))
 
-    @abc.abstractmethod
     def table_type(self, tir):
-        pass
+        payload = IRTypePayload(ir=self._render_ir(tir))
+        type_bytes, _ = self._rpc(ActionTag.TABLE_TYPE, payload)
+        return ttable._from_json(orjson.loads(type_bytes))
 
-    @abc.abstractmethod
     def matrix_type(self, mir):
-        pass
+        payload = IRTypePayload(ir=self._render_ir(mir))
+        type_bytes, _ = self._rpc(ActionTag.MATRIX_TABLE_TYPE, payload)
+        return tmatrix._from_json(orjson.loads(type_bytes))
 
+    def blockmatrix_type(self, bmir):
+        payload = IRTypePayload(ir=self._render_ir(bmir))
+        type_bytes, _ = self._rpc(ActionTag.BLOCK_MATRIX_TYPE, payload)
+        return tblockmatrix._from_json(orjson.loads(type_bytes))
+
-    @abc.abstractmethod
     def load_references_from_dataset(self, path):
-        pass
+        payload = LoadReferencesFromDatasetPayload(path=path)
+        references_json_bytes, _ = self._rpc(ActionTag.LOAD_REFERENCES_FROM_DATASET, payload)
+        return orjson.loads(references_json_bytes)
 
-    @abc.abstractmethod
     def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par):
-        pass
+        payload = FromFASTAFilePayload(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par)
+        rg_json_bytes, _ = self._rpc(ActionTag.FROM_FASTA_FILE, payload)
+        return orjson.loads(rg_json_bytes)
 
     def add_reference(self, rg):
         self._references[rg.name] = rg
@@ -184,9 +260,10 @@ def add_liftover(self, name, chain_file, dest_reference_genome):
     def remove_liftover(self, name, dest_reference_genome):
         pass
 
-    @abc.abstractmethod
     def parse_vcf_metadata(self, path):
-        pass
+        payload = ParseVCFMetadataPayload(path)
+        metadata_json_bytes, _ = self._rpc(ActionTag.PARSE_VCF_METADATA, payload)
+        return orjson.loads(metadata_json_bytes)
 
     @property
     @abc.abstractmethod
@@ -198,9 +275,10 @@ def logger(self):
     def fs(self) -> FS:
         pass
 
-    @abc.abstractmethod
     def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str):
-        pass
+        payload = ImportFamPayload(path, quant_pheno, delimiter, missing)
+        fam_json_bytes, _ = self._rpc(ActionTag.IMPORT_FAM, payload)
+        return orjson.loads(fam_json_bytes)
 
     def persist(self, dataset: Dataset) -> Dataset:
         from hail.context import TemporaryFilename
@@ -242,7 +320,7 @@ def _initialize_flags(self, initial_flags: Dict[str, str]) -> None:
         }, **initial_flags)
 
     @abc.abstractmethod
-    def set_flags(self, **flags: Mapping[str, str]):
+    def set_flags(self, **flags: str):
        """Set Hail flags."""
        pass
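As a sketch of how a concrete backend might satisfy the new abstract
`_rpc` over the HTTP endpoints this commit introduces: the client POSTs
the serialized payload to a per-action route and hands back the raw
response bytes, so large results never pass through a Base64 `String`.
The route names, the timings header, and the use of `requests` below
are assumptions for illustration, not the commit's implementation:

import orjson
import requests
from dataclasses import asdict
from typing import Tuple

class HttpRpcMixin:
    # Hypothetical action -> endpoint mapping; the real routes are defined in
    # the Spark/local backend servers changed elsewhere in this commit.
    _routes = {
        ActionTag.EXECUTE: '/execute',
        ActionTag.VALUE_TYPE: '/value_type',
    }

    def __init__(self, base_url: str):
        self._base_url = base_url

    def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]:
        resp = requests.post(self._base_url + self._routes[action],
                             data=orjson.dumps(asdict(payload)))
        resp.raise_for_status()
        # The body is the raw result bytes; timings, if requested, ride in a header.
        return resp.content, resp.headers.get('X-Hail-Timings', '')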
