Skip to content

Commit

Permalink
[vds] Stop dropping GT/PGT in reference data during import
Browse files Browse the repository at this point in the history
The GT/PGT fields in reference data may carry ploidy information, so we
need to stop dropping them during import.
  • Loading branch information
chrisvittal committed May 22, 2024
1 parent d63c91e commit 56baa59
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
13 changes: 10 additions & 3 deletions hail/python/hail/vds/combiner/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,12 @@ def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL by default, but if `entry_to_keep` has it then PL needs to be
# turned into LPL
handled_names = {'AD', 'PL'}
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
Expand Down Expand Up @@ -230,7 +234,6 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
)
_transform_variant_function_map[row_type, info_key] = transform_row

from hail.expr import construct_expr
from hail.utils.java import Env

uid = Env.get_uid()
Expand All @@ -252,8 +255,12 @@ def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL by default, but if `entry_to_keep` has it then PL needs to be
# turned into LPL
handled_names = {'AD', 'PL'}
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
Expand Down
14 changes: 9 additions & 5 deletions hail/python/hail/vds/combiner/variant_dataset_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,9 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine
if vds_paths:
# sync up gvcf_reference_entry_fields_to_keep and the reference entry types from the VDS
vds = hl.vds.read_vds(vds_paths[0], _warn_no_ref_block_max_length=False)
vds_ref_entry = set(vds.reference_data.entry) - {'END'}
vds_ref_entry = set(
name[1:] if name in ('LGT', 'LPGT') else name for name in vds.reference_data.entry if name != 'END'
)
if gvcf_reference_entry_fields_to_keep is not None and vds_ref_entry != gvcf_reference_entry_fields_to_keep:
warning(
"Mismatch between 'gvcf_reference_entry_fields' to keep and VDS reference data "
Expand All @@ -806,9 +808,11 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine

# sync up call_fields and call fields present in the VDS
all_entry_types = chain(vds.reference_data._type.entry_type.items(), vds.variant_data._type.entry_type.items())
vds_call_fields = {name for name, typ in all_entry_types if typ == hl.tcall} - {'LGT', 'GT'}
if 'LPGT' in vds_call_fields:
vds_call_fields = (vds_call_fields - {'LPGT'}) | {'PGT'}
vds_call_fields = {
name[1:] if name == 'LPGT' else name
for name, typ in all_entry_types
if typ == hl.tcall and name not in ('LGT', 'GT')
}
if set(call_fields) != vds_call_fields:
warning(
"Mismatch between 'call_fields' and VDS call fields. "
Expand All @@ -830,7 +834,7 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine
gvcf_type = mt._type
if gvcf_reference_entry_fields_to_keep is None:
rmt = mt.filter_rows(hl.is_defined(mt.info.END))
gvcf_reference_entry_fields_to_keep = defined_entry_fields(rmt, 100_000) - {'GT', 'PGT', 'PL'}
gvcf_reference_entry_fields_to_keep = defined_entry_fields(rmt, 100_000) - {'PL'}
if vds is None:
vds = transform_gvcf(
mt._key_rows_by_assert_sorted('locus'), gvcf_reference_entry_fields_to_keep, gvcf_info_to_keep
Expand Down
16 changes: 12 additions & 4 deletions hail/python/hail/vds/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,24 @@ def coalesce_join(ref, var):
call_field = 'GT' if 'GT' in var else 'LGT'
assert call_field in var, var.dtype

if call_field in ref:
ref_call_field = call_field
else:
ref_call_field = 'GT' if 'GT' in ref else 'LGT'

if ref_call_field not in ref:
ref = ref.annotate(**{call_field: hl.call(0, 0)})
if call_field not in ref:
ref = ref.annotate(**{call_field: ref[ref_call_field]})

shared_fields = [call_field, *list(f for f in ref.dtype if f in var.dtype)]
shared_field_set = set(shared_fields)
var_fields = [f for f in var.dtype if f not in shared_field_set]

return hl.if_else(
hl.is_defined(var),
var.select(*shared_fields, *var_fields),
ref.annotate(**{call_field: hl.call(0, 0)}).select(
*shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields}
),
ref.select(*shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields}),
)

dr = dr.annotate(
Expand Down Expand Up @@ -138,7 +146,7 @@ def rewrite_ref(r):
for k, t in merged_schema.items():
if k == 'LA':
ref_block_selector[k] = hl.literal([0])
elif k in ('LGT', 'GT'):
elif k in ('LGT', 'GT') and k not in r:
ref_block_selector[k] = hl.call(0, 0)
else:
ref_block_selector[k] = r[k] if k in r else hl.missing(t)
Expand Down

0 comments on commit 56baa59

Please sign in to comment.