Skip to content

Commit

Permalink
[vds/combiner] Stop dropping GT in reference data during gvcf import (#…
Browse files Browse the repository at this point in the history
…14560)

CHANGELOG: The gvcf import stage of the VDS combiner now preserves the
GT of reference blocks. Some datasets have haploid calls on sex
chromosomes, and the fact that the reference was haploid should be
preserved.
  • Loading branch information
chrisvittal committed Jul 30, 2024
1 parent 56138e8 commit cc26435
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 21 deletions.
25 changes: 16 additions & 9 deletions hail/python/hail/vds/combiner/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,14 @@ def make_reference_stream(stream, entry_to_keep: Collection[str]):

def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL by default, but if `entry_to_keep` has it then PL needs to be
# turned into LPL
handled_names = {'AD', 'PL'}

# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
Expand Down Expand Up @@ -230,7 +234,6 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
)
_transform_variant_function_map[row_type, info_key] = transform_row

from hail.expr import construct_expr
from hail.utils.java import Env

uid = Env.get_uid()
Expand All @@ -250,10 +253,14 @@ def make_reference_matrix_table(mt: MatrixTable, entry_to_keep: Collection[str])

def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL by default, but if `entry_to_keep` has it then PL needs to be
# turned into LPL
handled_names = {'AD', 'PL'}

# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
Expand Down
14 changes: 9 additions & 5 deletions hail/python/hail/vds/combiner/variant_dataset_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,9 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine
if vds_paths:
# sync up gvcf_reference_entry_fields_to_keep and they reference entry types from the VDS
vds = hl.vds.read_vds(vds_paths[0], _warn_no_ref_block_max_length=False)
vds_ref_entry = set(vds.reference_data.entry) - {'END'}
vds_ref_entry = set(
name[1:] if name in ('LGT', 'LPGT') else name for name in vds.reference_data.entry if name != 'END'
)
if gvcf_reference_entry_fields_to_keep is not None and vds_ref_entry != gvcf_reference_entry_fields_to_keep:
warning(
"Mismatch between 'gvcf_reference_entry_fields' to keep and VDS reference data "
Expand All @@ -806,9 +808,11 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine

# sync up call_fields and call fields present in the VDS
all_entry_types = chain(vds.reference_data._type.entry_type.items(), vds.variant_data._type.entry_type.items())
vds_call_fields = {name for name, typ in all_entry_types if typ == hl.tcall} - {'LGT', 'GT'}
if 'LPGT' in vds_call_fields:
vds_call_fields = (vds_call_fields - {'LPGT'}) | {'PGT'}
vds_call_fields = {
name[1:] if name == 'LPGT' else name
for name, typ in all_entry_types
if typ == hl.tcall and name not in ('LGT', 'GT')
}
if set(call_fields) != vds_call_fields:
warning(
"Mismatch between 'call_fields' and VDS call fields. "
Expand All @@ -830,7 +834,7 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine
gvcf_type = mt._type
if gvcf_reference_entry_fields_to_keep is None:
rmt = mt.filter_rows(hl.is_defined(mt.info.END))
gvcf_reference_entry_fields_to_keep = defined_entry_fields(rmt, 100_000) - {'GT', 'PGT', 'PL'}
gvcf_reference_entry_fields_to_keep = defined_entry_fields(rmt, 100_000) - {'PGT', 'PL'}
if vds is None:
vds = transform_gvcf(
mt._key_rows_by_assert_sorted('locus'), gvcf_reference_entry_fields_to_keep, gvcf_info_to_keep
Expand Down
19 changes: 12 additions & 7 deletions hail/python/hail/vds/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,21 @@ def coalesce_join(ref, var):
call_field = 'GT' if 'GT' in var else 'LGT'
assert call_field in var, var.dtype

shared_fields = [call_field, *list(f for f in ref.dtype if f in var.dtype)]
shared_field_set = set(shared_fields)
var_fields = [f for f in var.dtype if f not in shared_field_set]
if call_field not in ref:
ref_call_field = 'GT' if 'GT' in ref else 'LGT'
if ref_call_field not in ref:
ref = ref.annotate(**{call_field: hl.call(0, 0)})
else:
ref = ref.annotate(**{call_field: ref[ref_call_field]})

# call_field is now in both ref and var
ref_set, var_set = set(ref.dtype), set(var.dtype)
shared_fields, var_fields = var_set & ref_set, var_set - ref_set

return hl.if_else(
hl.is_defined(var),
var.select(*shared_fields, *var_fields),
ref.annotate(**{call_field: hl.call(0, 0)}).select(
*shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields}
),
ref.select(*shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields}),
)

dr = dr.annotate(
Expand Down Expand Up @@ -138,7 +143,7 @@ def rewrite_ref(r):
for k, t in merged_schema.items():
if k == 'LA':
ref_block_selector[k] = hl.literal([0])
elif k in ('LGT', 'GT'):
elif k in ('LGT', 'GT') and k not in r:
ref_block_selector[k] = hl.call(0, 0)
else:
ref_block_selector[k] = r[k] if k in r else hl.missing(t)
Expand Down
3 changes: 3 additions & 0 deletions hail/python/test/hail/vds/test_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def test_combiner_works():
assert 'LPGT' in comb.variant_data.entry
assert comb.variant_data.LPGT.dtype == hl.tcall

# see https://github.com/hail-is/hail/issues/14564 for why this assertion is here
assert 'LGT' in comb.reference_data.entry

assert len(parts) == comb.variant_data.n_partitions()
comb.variant_data._force_count_rows()
comb.reference_data._force_count_rows()
Expand Down

0 comments on commit cc26435

Please sign in to comment.