[query] Use EncodedLiteral instead of Literal from python to scala #13814

Merged Oct 25, 2023 (40 commits)
Changes from 15 commits
Commits
6ecc90b
[query] Use EncodedLiteral instead of Literal from python to scala
daniel-goldstein Oct 12, 2023
8f2ed10
make encoding/decoding void an error
daniel-goldstein Oct 14, 2023
00a0f42
explicitly decode tstr as utf-8
daniel-goldstein Oct 14, 2023
8c9e459
implement call
daniel-goldstein Oct 14, 2023
cc697d1
dont try to decode void results
daniel-goldstein Oct 14, 2023
9080922
set requiredness on decoded ndarray element types
daniel-goldstein Oct 16, 2023
9cb4764
change name of fromTypeAllOptional to fromPythonTypeEncoding
daniel-goldstein Oct 16, 2023
2c339e6
transposition encodes ok
daniel-goldstein Oct 16, 2023
dc3b935
encoded literals are constants too
daniel-goldstein Oct 16, 2023
272b511
some more tests
daniel-goldstein Oct 16, 2023
4f72864
fix ndarrays of zero size
daniel-goldstein Oct 17, 2023
1471073
fix copying of EncodedLiteral nodes
daniel-goldstein Oct 17, 2023
4717c9f
v does not exist
daniel-goldstein Oct 17, 2023
134bc2d
Merge branch 'main' into encode-literals-python-to-scala
daniel-goldstein Oct 17, 2023
5c89d8a
add EncodedLiteral to ExtractIntervalFilters
daniel-goldstein Oct 17, 2023
39de392
ArrayMaximalIndependentSet is not Interpretable
daniel-goldstein Oct 18, 2023
2559a3e
some progress maybe
daniel-goldstein Oct 18, 2023
bff6d68
more wip
daniel-goldstein Oct 18, 2023
2813af2
wip
daniel-goldstein Oct 18, 2023
3d344ee
add that file that i forgot
daniel-goldstein Oct 18, 2023
b5719f9
do the sorting thing
daniel-goldstein Oct 19, 2023
2a9aea0
memoize
daniel-goldstein Oct 19, 2023
4a26ee5
fix missingness
daniel-goldstein Oct 19, 2023
4d34009
clean up cruft
daniel-goldstein Oct 19, 2023
4abf0d5
not sure if this actually does anything
daniel-goldstein Oct 19, 2023
64624fc
allocate the original array in the destination region
daniel-goldstein Oct 20, 2023
def2154
more tests
daniel-goldstein Oct 20, 2023
20d0aca
wip
daniel-goldstein Oct 20, 2023
323553c
Merge branch 'main' into encode-literals-python-to-scala
daniel-goldstein Oct 23, 2023
e49630c
fix
daniel-goldstein Oct 23, 2023
77c21d0
bring back wirespec
daniel-goldstein Oct 23, 2023
a88f5a7
debugging
daniel-goldstein Oct 24, 2023
630b448
test
daniel-goldstein Oct 24, 2023
c5e86e4
wip
daniel-goldstein Oct 24, 2023
353dea8
python and scala must agree that dict elements are required
daniel-goldstein Oct 24, 2023
3501bd7
comment out some debugging information
daniel-goldstein Oct 24, 2023
761025d
add unsorted set etype for sets coming from python
daniel-goldstein Oct 25, 2023
1177e1d
fix bug in TypeInfo
daniel-goldstein Oct 25, 2023
e24fe97
remove a bunch of debugging prints
daniel-goldstein Oct 25, 2023
d47433a
fix
daniel-goldstein Oct 25, 2023
8 changes: 6 additions & 2 deletions hail/python/hail/backend/py4j_backend.py
@@ -13,7 +13,7 @@
from hail.expr.blockmatrix_type import tblockmatrix
from hail.expr.matrix_type import tmatrix
from hail.expr.table_type import ttable
from hail.expr.types import dtype
from hail.expr.types import dtype, tvoid

from .backend import Backend, fatal_error_from_java_error_triplet

@@ -75,7 +75,11 @@ def execute(self, ir, timed=False):
try:
result_tuple = self._jbackend.executeEncode(jir, stream_codec, timed)
(result, timings) = (result_tuple._1(), result_tuple._2())
value = ir.typ._from_encoding(result)

if ir.typ == tvoid:
value = None
else:
value = ir.typ._from_encoding(result)

return (value, timings) if timed else value
except FatalError as e:
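A quick illustration of why the tvoid special case is needed (not part of the diff; it assumes the types.py change below, which makes decoding void an error):

    import hail as hl

    try:
        hl.tvoid._from_encoding(b'')
    except ValueError as err:
        print(err)  # "Cannot decode void type"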
2 changes: 1 addition & 1 deletion hail/python/hail/expr/functions.py
@@ -353,7 +353,7 @@ def typecheck_expr(t, x):
assert isinstance(x, builtins.str)
return construct_expr(ir.Str(x), tstr)
else:
return construct_expr(ir.Literal(dtype, x), dtype)
return construct_expr(ir.EncodedLiteral(dtype, x), dtype)


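Usage sketch of the effect of this change (illustrative; the _ir attribute name is assumed, not shown in the diff):

    import hail as hl
    from hail import ir

    expr = hl.literal([1, None, 3])
    print(type(expr._ir).__name__)  # expected: EncodedLiteral rather than Literal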
@deprecated(version="0.2.59", reason="Replaced by hl.if_else")
147 changes: 141 additions & 6 deletions hail/python/hail/expr/types.py
@@ -18,7 +18,7 @@
from .. import genetics
from ..typecheck import typecheck, typecheck_method, oneof, transformed, nullable
from ..utils.struct import Struct
from ..utils.byte_reader import ByteReader
from ..utils.byte_reader import ByteReader, ByteWriter
from ..utils.misc import lookup_bit
from ..utils.java import escape_parsable
from ..genetics.reference_genome import reference_genome_type
@@ -275,9 +275,17 @@ def _convert_from_json(self, x, _should_freeze: bool = False):
def _from_encoding(self, encoding):
return self._convert_from_encoding(ByteReader(memoryview(encoding)))

def _to_encoding(self, value) -> bytes:
buf = bytearray()
self._convert_to_encoding(ByteWriter(buf), value)
return bytes(buf)

def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False):
raise ValueError("Not implemented yet")

def _convert_to_encoding(self, byte_writer, value):
raise ValueError("Not implemented yet")

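A minimal round-trip sketch of the new encoding entry points (illustrative, not part of the diff):

    import hail as hl

    t = hl.tarray(hl.tint32)
    payload = t._to_encoding([1, None, 3])  # bytes in the Python-side wire encoding
    print(t._from_encoding(payload))        # expected: [1, None, 3]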
def _traverse(self, obj, f):
"""Traverse a nested type and object.

@@ -339,8 +347,11 @@ def subst(self):
def clear(self):
pass

def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False):
return None
def _convert_from_encoding(self, *_):
raise ValueError("Cannot decode void type")

def _convert_to_encoding(self, *_):
raise ValueError("Cannot encode void type")


class _tint32(HailType):
@@ -395,6 +406,9 @@ def to_numpy(self):
def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> int:
return byte_reader.read_int32()

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
byte_writer.write_int32(value)

def _byte_size(self):
return 4

@@ -450,6 +464,9 @@ def to_numpy(self):
def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> int:
return byte_reader.read_int64()

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
byte_writer.write_int64(value)

def _byte_size(self):
return 8

@@ -488,6 +505,9 @@ def _convert_to_json(self, x):
def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> float:
return byte_reader.read_float32()

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
byte_writer.write_float32(value)

def unify(self, t):
return t == tfloat32

@@ -550,6 +570,9 @@ def to_numpy(self):
def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> float:
return byte_reader.read_float64()

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
byte_writer.write_float64(value)

def _byte_size(self):
return 8

@@ -587,10 +610,15 @@ def clear(self):

def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> str:
length = byte_reader.read_int32()
str_literal = byte_reader.read_bytes(length).decode()
str_literal = byte_reader.read_bytes(length).decode('utf-8')

return str_literal

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
value_bytes = value.encode('utf-8')
byte_writer.write_int32(len(value_bytes))
byte_writer.write_bytes(value_bytes)

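Byte-layout sketch for the string writer above (illustrative; the 4-byte little-endian length prefix is an assumption meant to mirror the reader):

    s = "héllo"
    encoded = s.encode('utf-8')
    assert len(s) == 5 and len(encoded) == 6                  # 'é' is two bytes in UTF-8
    payload = len(encoded).to_bytes(4, 'little') + encoded    # assumed prefix layout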

class _tbool(HailType):
"""Hail type for Boolean (``True`` or ``False``) values.
@@ -632,6 +660,8 @@ def _byte_size(self):
def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> bool:
return byte_reader.read_bool()

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
byte_writer.write_bool(value)

class _trngstate(HailType):

@@ -789,6 +819,17 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> n
np_type = self.element_type.to_numpy()
return np.ndarray(shape=shape, buffer=np.array(elements, dtype=np_type), dtype=np_type, order="F")

def _convert_to_encoding(self, byte_writer, value: np.ndarray):
for dim in value.shape:
byte_writer.write_int64(dim)

if value.size > 0:
if self.element_type in _numeric_types:
byte_writer.write_bytes(value.data)
else:
for elem in np.nditer(value, order='F'):
self.element_type._convert_to_encoding(byte_writer, elem)
Contributor review comment: I feel good about this now. You're using ENDArrayColumnMajor and the column ordering is Fortran, so it all checks out.

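A quick check of the ordering claim in that comment (illustrative, not part of the diff): np.nditer(order='F') visits elements column-major, matching a column-major (Fortran) encoding.

    import numpy as np

    a = np.array([[1, 2, 3],
                  [4, 5, 6]])
    assert [int(x) for x in np.nditer(a, order='F')] == [1, 4, 2, 5, 3, 6]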


class tarray(HailType):
"""Hail type for variable-length arrays of elements.
@@ -900,6 +941,23 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> U
return decoded


def _convert_to_encoding(self, byte_writer: ByteWriter, value):
length = len(value)
byte_writer.write_int32(length)
i = 0
while i < length:
missing_byte = 0
for j in range(min(8, length - i)):
if value[i + j] is None:
Contributor suggested change:
-                if value[i + j] is None:
+                if value[i + j] in (None, pd.NA):

missing_byte |= 1 << j
byte_writer.write_byte(missing_byte)
i += 8

for element in value:
if element is not None:
self.element_type._convert_to_encoding(byte_writer, element)

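Worked illustration of the missingness packing above (not part of the diff): each group of up to eight elements gets one bitmap byte, with bits set at the missing positions, and only present elements are written afterwards.

    value = [1, None, 3, None]
    missing_byte = 0
    for j, v in enumerate(value[:8]):
        if v is None:
            missing_byte |= 1 << j
    assert missing_byte == 0b00001010                     # bits 1 and 3 flag the Nones
    assert [v for v in value if v is not None] == [1, 3]  # only these get encoded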

class tstream(HailType):
@typecheck_method(element_type=hail_type)
def __init__(self, element_type):
@@ -1038,6 +1096,9 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> U
return frozenset(s)
return set(s)

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
self._array_repr._convert_to_encoding(byte_writer, list(value))

def _propagate_jtypes(self, jtype):
self._element_type._add_jtype(jtype.elementType())

@@ -1064,6 +1125,9 @@ def _convert_from_json_na(self, x, _should_freeze: bool = False):
def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False):
return self.t._convert_from_encoding(byte_reader, _should_freeze=True)

def _convert_to_encoding(self, byte_writer, x):
return self.t._convert_to_encoding(byte_writer, x)


class tdict(HailType):
"""Hail type for key-value maps.
@@ -1166,6 +1230,10 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> U
return frozendict(d)
return d

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
array_of_pairs = [{'key': k, 'value': v} for k, v in value.items()]
self._array_repr._convert_to_encoding(byte_writer, array_of_pairs)

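Illustration of the array-of-pairs representation used above (not part of the diff): a dict is written exactly as the equivalent array of key/value structs would be.

    value = {'x': 1, 'y': 2}
    array_of_pairs = [{'key': k, 'value': v} for k, v in value.items()]
    assert array_of_pairs == [{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]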
def _propagate_jtypes(self, jtype):
self._key_type._add_jtype(jtype.keyType())
self._value_type._add_jtype(jtype.valueType())
@@ -1332,7 +1400,7 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> S
kwargs = {}

current_missing_byte = None
for i, (f, t) in enumerate(self._field_types.items()):
for i, (f, t) in enumerate(self.items()):
which_missing_bit = i % 8
if which_missing_bit == 0:
current_missing_byte = missing_bytes[i // 8]
@@ -1345,6 +1413,22 @@

return Struct(**kwargs)

def _convert_to_encoding(self, byte_writer: ByteWriter, value):
keys = list(self.keys())
length = len(keys)
i = 0
while i < length:
missing_byte = 0
for j in range(min(8, length - i)):
if value[keys[i + j]] is None:
Contributor review comment: same here for missing values from pandas.

missing_byte |= 1 << j
byte_writer.write_byte(missing_byte)
i += 8

for f, t in self.items():
if value[f] is not None:
t._convert_to_encoding(byte_writer, value[f])

def _is_prefix_of(self, other):
return (isinstance(other, tstruct)
and len(self._fields) <= len(other._fields)
@@ -1619,6 +1703,20 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> t

return tuple(answer)

def _convert_to_encoding(self, byte_writer, value):
length = len(self)
i = 0
while i < length:
missing_byte = 0
for j in range(min(8, length - i)):
if value[i + j] is None:
missing_byte |= 1 << j
byte_writer.write_byte(missing_byte)
i += 8
for i, t in enumerate(self.types):
if value[i] is not None:
t._convert_to_encoding(byte_writer, value[i])

def unify(self, t):
if not (isinstance(t, ttuple) and len(self.types) == len(t.types)):
return False
@@ -1638,7 +1736,7 @@ def _get_context(self):
return HailTypeContext.union(*self.types)


def allele_pair(j, k):
def allele_pair(j: int, k: int):
assert j >= 0 and j <= 0xffff
assert k >= 0 and k <= 0xffff
return j | (k << 16)
@@ -1754,6 +1852,31 @@ def call_allele_pair(i):

return genetics.Call(alleles, phased)

def _convert_to_encoding(self, byte_writer, value: genetics.Call):
int_rep = 0

int_rep |= value.ploidy << 1
if value.phased:
int_rep |= 1

def diploid_gt_index(j: int, k: int):
assert j <= k
return k * (k + 1) // 2 + j

def allele_pair_rep(c: genetics.Call):
[j, k] = c.alleles
if c.phased:
return diploid_gt_index(j, j + k)
return diploid_gt_index(j, k)

assert value.ploidy <= 2
if value.ploidy == 1:
int_rep |= value.alleles[0] << 3
elif value.ploidy == 2:
int_rep |= allele_pair_rep(value) << 3

byte_writer.write_int32(int_rep)

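Worked example of the packed call representation above (illustrative, not part of the diff), for an unphased diploid call 0/1: ploidy 2 shifted left by one gives 0b100, the phased bit stays 0, and the allele-pair index 1 shifted left by three gives 0b1000, so int_rep is 12.

    def diploid_gt_index(j: int, k: int) -> int:
        return k * (k + 1) // 2 + j

    int_rep = (2 << 1) | 0 | (diploid_gt_index(0, 1) << 3)
    assert int_rep == 0b1100 == 12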
def unify(self, t):
return t == tcall

@@ -1839,6 +1962,9 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> g
as_struct = tlocus.struct_repr._convert_from_encoding(byte_reader)
return genetics.Locus(as_struct.contig, as_struct.pos, self.reference_genome)

def _convert_to_encoding(self, byte_writer, value: genetics.Locus):
tlocus.struct_repr._convert_to_encoding(byte_writer, {'contig': value.contig, 'pos': value.position})

def unify(self, t):
return isinstance(t, tlocus) and self.reference_genome == t.reference_genome

@@ -1936,6 +2062,15 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False):
interval_as_struct.includes_end,
point_type=self.point_type)

def _convert_to_encoding(self, byte_writer, value):
interval_dict = {
'start': value.start,
'end': value.end,
'includes_start': value.includes_start,
'includes_end': value.includes_end,
}
self._struct_repr._convert_to_encoding(byte_writer, interval_dict)

def unify(self, t):
return isinstance(t, tinterval) and self.point_type.unify(t.point_type)

2 changes: 1 addition & 1 deletion hail/python/hail/genetics/call.py
@@ -87,7 +87,7 @@ def __getitem__(self, item):
return self._alleles[item]

@property
def alleles(self):
def alleles(self) -> Sequence[int]:
"""Get the alleles of this call.

Returns
3 changes: 2 additions & 1 deletion hail/python/hail/ir/__init__.py
@@ -3,7 +3,7 @@
from .ir import MatrixWrite, MatrixMultiWrite, BlockMatrixWrite, \
BlockMatrixMultiWrite, TableToValueApply, \
MatrixToValueApply, BlockMatrixToValueApply, BlockMatrixCollect, \
Literal, LiftMeOut, Join, JavaIR, I32, I64, F32, F64, Str, FalseIR, TrueIR, \
Literal, EncodedLiteral, LiftMeOut, Join, JavaIR, I32, I64, F32, F64, Str, FalseIR, TrueIR, \
Void, Cast, NA, IsNA, If, Coalesce, Let, AggLet, Ref, TopLevelReference, ProjectedTopLevelReference, SelectedTopLevelReference, \
TailLoop, Recur, ApplyBinaryPrimOp, ApplyUnaryPrimOp, ApplyComparisonOp, \
MakeArray, ArrayRef, ArraySlice, ArrayLen, ArrayZeros, StreamIota, StreamRange, StreamGrouped, MakeNDArray, \
@@ -227,6 +227,7 @@
'MatrixToValueApply',
'BlockMatrixToValueApply',
'Literal',
'EncodedLiteral',
'LiftMeOut',
'Join',
'JavaIR',
29 changes: 29 additions & 0 deletions hail/python/hail/ir/ir.py
@@ -1,5 +1,6 @@
from typing import Callable, TypeVar, cast
from typing_extensions import ParamSpec
import base64
import copy
import json
from collections import defaultdict
@@ -3644,6 +3645,34 @@ def _compute_type(self, env, agg_env, deep_typecheck):
return self._typ


class EncodedLiteral(IR):
@typecheck_method(typ=hail_type, value=anytype, encoded_value=nullable(str))
def __init__(self, typ, value, *, encoded_value = None):
super(EncodedLiteral, self).__init__()
self._typ: HailType = typ
self._value = value
self._encoded_value = encoded_value

@property
def encoded_value(self):
if self._encoded_value is None:
self._encoded_value = base64.b64encode(self._typ._to_encoding(self._value)).decode('utf-8')
return self._encoded_value

def copy(self):
return EncodedLiteral(self._typ, self._value, encoded_value=self._encoded_value)

def head_str(self):
return f'{self._typ._parsable_string()} "{self.encoded_value}"'

def _eq(self, other):
return other._typ == self._typ and \
other.encoded_value == self.encoded_value

def _compute_type(self, env, agg_env, deep_typecheck):
return self._typ

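Illustrative sketch of how an EncodedLiteral travels to the JVM (not part of the diff): head_str() renders the type's parsable string plus a base64 encoding of the value's wire bytes, so the Scala side can decode the literal without going through JSON.

    import hail as hl
    from hail import ir

    lit = ir.EncodedLiteral(hl.tint32, 5)
    print(lit.head_str())   # e.g. Int32 "..." where ... is base64 of the encoded 5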

class LiftMeOut(IR):
@typecheck_method(child=IR)
def __init__(self, child):