Skip to content

Commit

Permalink
[query] fix & improve pprint for hl.Struct (#12901)
Browse files Browse the repository at this point in the history
CHANGELOG: `hl.Struct` now has a correct and useful implementation of
pprint.

For structs with keys that were not identifiers, we produced incorrect
`repr` output. For `pprint`, we just `pprint`'ed a dictionary (so you
cannot even tell that the object was an `hl.Struct`). This PR:

1. Fixes `hl.Struct.__str__` to use the kwargs or dictionary
representation based on whether the keys are Python identifiers.
2. Teaches `StructPrettyPrinter` to first try to `repr` the struct (this
is what the default pretty printer does)
3. Teaches `StructPrettyPrinter` to properly pretty print a struct as an
`hl.Struct` preferring the kwarg representation when appropriate.
4. Teaches `_same` to use pretty printing when showing differing
records.
  • Loading branch information
danking authored Apr 28, 2023
1 parent 1389874 commit 948b1d9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
5 changes: 3 additions & 2 deletions hail/python/hail/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas
import numpy as np
import pyspark
import pprint
from typing import Optional, Dict, Callable, Sequence, Union

from hail.expr.expressions import Expression, StructExpression, \
Expand Down Expand Up @@ -3669,7 +3670,7 @@ def _same(self, other, tolerance=1e-6, absolute=False, reorder_fields=False):

if not hl.eval(_values_similar(t[left_global_value], t[right_global_value], tolerance, absolute)):
g = hl.eval(t.globals)
print(f'Table._same: globals differ: {g[left_global_value]}, {g[right_global_value]}')
print(f'Table._same: globals differ:\n{pprint.pformat(g[left_global_value])}\n{pprint.pformat(g[right_global_value])}')
return False

if not t.all(hl.is_defined(t[left_value]) & hl.is_defined(t[right_value])
Expand All @@ -3678,7 +3679,7 @@ def _same(self, other, tolerance=1e-6, absolute=False, reorder_fields=False):
t = t.filter(~ _values_similar(t[left_value], t[right_value], tolerance, absolute))
bad_rows = t.take(10)
for r in bad_rows:
print(f' Row mismatch at key={r._key}:\n L: {r[left_value]}\n R: {r[right_value]}')
print(f' Row mismatch at key={r._key}:\n Left:\n{pprint.pformat(r[left_value])}\n Right:\n{pprint.pformat(r[right_value])}')
return False

return True
Expand Down
58 changes: 54 additions & 4 deletions hail/python/hail/utils/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,17 @@ def __repr__(self):
return str(self)

def __str__(self):
return 'Struct({})'.format(', '.join('{}={}'.format(k, repr(v)) for k, v in self._fields.items()))
if all(k.isidentifier() for k in self._fields):
return (
'Struct('
+ ', '.join(f'{k}={repr(v)}' for k, v in self._fields.items())
+ ')'
)
return (
'Struct(**{'
+ ', '.join(f'{repr(k)}: {repr(v)}' for k, v in self._fields.items())
+ '})'
)

def __eq__(self, other):
return isinstance(other, Struct) and self._fields == other._fields
Expand Down Expand Up @@ -241,10 +251,50 @@ def to_dict(struct):


class StructPrettyPrinter(pprint.PrettyPrinter):
def _format(self, obj, *args, **kwargs):
def _format(self, obj, stream, indent, allowance, context, level, *args, **kwargs):
if isinstance(obj, Struct):
obj = to_dict(obj)
return _old_printer._format(self, obj, *args, **kwargs)
rep = self._repr(obj, context, level)
max_width = self._width - indent - allowance
if len(rep) <= max_width:
stream.write(rep)
return

stream.write('Struct(')
indent += len('Struct(')
if all(k.isidentifier() for k in obj):
n = len(obj.items())
for i, (k, v) in enumerate(obj.items()):
is_first = i == 0
is_last = i == n - 1

if not is_first:
stream.write(' ' * indent)
stream.write(k)
stream.write('=')
this_indent = indent + len(k) + len('=')
self._format(v, stream, this_indent, allowance, context, level, *args, **kwargs)
if not is_last:
stream.write(',\n')
else:
stream.write('**{')
indent += len('**{')
n = len(obj.items())
for i, (k, v) in enumerate(obj.items()):
is_first = i == 0
is_last = i == n - 1

if not is_first:
stream.write(' ' * indent)
stream.write(repr(k))
stream.write(': ')
this_indent = indent + len(repr(k)) + len(': ')
self._format(v, stream, this_indent, allowance, context, level, *args, **kwargs)
if not is_last:
stream.write(',\n')
stream.write('}')
stream.write(')')
else:
_old_printer._format(self, obj, stream, indent, allowance, context, level, *args, **kwargs)


pprint.PrettyPrinter = StructPrettyPrinter # monkey-patch pprint

0 comments on commit 948b1d9

Please sign in to comment.