[vds/combiner] Deduplicate functions in combiner internals
`make_variants_matrix_table` and `make_variant_stream` shared a function
for processing entries, as did `make_reference_matrix_table` and
`make_reference_stream`. This commit deduplicates that functionality by
lifting the internal functions out to top-level definitions.
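
A minimal, hypothetical sketch of the pattern (generic names, not the actual combine.py code): a helper that each function defined locally, capturing values from its enclosing scope, becomes a single top-level definition that takes those values as explicit parameters, just as `make_var_entry_struct` and `make_ref_entry_struct` now take `info_to_keep` and `entry_to_keep` in the diff below.

# Before the refactor, each caller defined its own copy of the same nested
# helper, which captured `label_default` from the enclosing scope:
def build_table_a_before(rows, label_default=''):
    def make_entry(e):  # local helper, duplicated in build_table_b_before
        return {'value': e['value'], 'label': e.get('label', label_default)}
    return [make_entry(e) for e in rows]

def build_table_b_before(rows, label_default=''):
    def make_entry(e):  # identical copy of the helper above
        return {'value': e['value'], 'label': e.get('label', label_default)}
    return [make_entry(e) for e in rows]

# After the refactor, there is a single top-level helper; anything it used to
# capture from the enclosing scope is now an explicit parameter:
def make_entry(e, label_default=''):
    return {'value': e['value'], 'label': e.get('label', label_default)}

def build_table_a(rows, label_default=''):
    return [make_entry(e, label_default) for e in rows]

def build_table_b(rows, label_default=''):
    return [make_entry(e, label_default) for e in rows]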
chrisvittal committed Jun 17, 2024
1 parent 63db10a commit 6fc485c
Showing 1 changed file with 86 additions and 166 deletions.
252 changes: 86 additions & 166 deletions hail/python/hail/vds/combiner/combine.py
@@ -20,6 +20,87 @@
_transform_reference_fuction_map: Dict[Tuple[HailType, Tuple[str, ...]], Function] = {}
_merge_function_map: Dict[Tuple[HailType, HailType], Function] = {}

def get_lgt(gt, n_alleles, has_non_ref, row):
index = gt.unphase().unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (
hl.case()
.when(gt.is_haploid(), hl.or_missing(gt[0] < n_no_nonref, gt))
.when(index < triangle_without_nonref, gt)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid call ' + hl.str(gt) + ' at site ' + hl.str(row.locus))
)

def make_var_entry_struct(e, info_to_keep, alleles_len, has_non_ref, row):
handled_fields = dict()
handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'}

if 'GT' not in e:
raise hl.utils.FatalError("the Hail VDS combiner expects input GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e.GT, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = (
e.PGT if e.PGT.dtype != hl.tcall else get_lgt(e.PGT, alleles_len, has_non_ref, row)
)
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(
has_non_ref,
hl.if_else(
alleles_len > 2,
hl.if_else(e.GT.is_haploid(), e.PL[:-1], e.PL[:-alleles_len]),
hl.missing(e.PL.dtype),
),
hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)),
)
handled_fields['RGQ'] = hl.if_else(
has_non_ref,
hl.if_else(
e.GT.is_haploid(),
e.PL[alleles_len - 1],
e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
),
hl.missing(e.PL.dtype.element_type),
)

handled_fields['gvcf_info'] = (
hl.case()
.when(
hl.is_missing(row.info.END),
parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref),
)
.or_missing()
)

pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

def make_ref_entry_struct(e, entry_to_keep, row):
handled_fields = dict()
# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
handled_fields['LPL'] = e['PL'][:1]

reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names}
return (
hl.case()
.when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields))
.or_error('found END with non reference-genotype at' + hl.str(row.locus))
)


def make_variants_matrix_table(mt: MatrixTable, info_to_keep: Optional[Collection[str]] = None) -> MatrixTable:
if info_to_keep is None:
@@ -33,65 +114,6 @@ def make_variants_matrix_table(mt: MatrixTable, info_to_keep: Optional[Collectio
transform_row = _transform_variant_function_map.get((mt.row.dtype, info_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):

def get_lgt(gt, n_alleles, has_non_ref, row):
index = gt.unphase().unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (
hl.case()
.when(gt.is_haploid(), hl.or_missing(gt[0] < n_no_nonref, gt))
.when(index < triangle_without_nonref, gt)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid call ' + hl.str(gt) + ' at site ' + hl.str(row.locus))
)

def make_entry_struct(e, alleles_len, has_non_ref, row):
handled_fields = dict()
handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'}

if 'GT' not in e:
raise hl.utils.FatalError("the Hail VDS combiner expects input GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e.GT, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = (
e.PGT if e.PGT.dtype != hl.tcall else get_lgt(e.PGT, alleles_len, has_non_ref, row)
)
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(
has_non_ref,
hl.if_else(
alleles_len > 2,
hl.if_else(e.GT.is_haploid(), e.PL[:-1], e.PL[:-alleles_len]),
hl.missing(e.PL.dtype),
),
hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)),
)
handled_fields['RGQ'] = hl.if_else(
has_non_ref,
hl.if_else(
e.GT.is_haploid(),
e.PL[alleles_len - 1],
e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
),
hl.missing(e.PL.dtype.element_type),
)

handled_fields['gvcf_info'] = (
hl.case()
.when(
hl.is_missing(row.info.END),
parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref),
)
.or_missing()
)

pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

transform_row = hl.experimental.define_function(
lambda row: hl.rbind(
hl.len(row.alleles),
@@ -100,7 +122,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
locus=row.locus,
alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
**({'rsid': row.rsid} if 'rsid' in row else {}),
__entries=row.__entries.map(lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)),
__entries=row.__entries.map(lambda e: make_var_entry_struct(e, info_to_keep, alleles_len, has_non_ref, row)),
),
),
mt.row.dtype,
@@ -120,39 +142,17 @@ def make_reference_stream(stream, entry_to_keep: Collection[str]):
stream = stream.filter(lambda elt: hl.is_defined(elt.info.END))
entry_key = tuple(sorted(entry_to_keep)) # hashable stable value

def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
handled_fields['LPL'] = e['PL'][:1]

reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names}
return (
hl.case()
.when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields))
.or_error('found END with non reference-genotype at' + hl.str(row.locus))
)

row_type = stream.dtype.element_type
transform_row = _transform_reference_fuction_map.get((row_type, entry_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
transform_row = hl.experimental.define_function(
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))),
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_ref_entry_struct(e, entry_to_keep, row))),
row_type,
)
_transform_reference_fuction_map[row_type, entry_key] = transform_row

return stream.map(
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row)))
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_ref_entry_struct(e, entry_to_keep, row)))
)


@@ -169,64 +169,6 @@ def make_variant_stream(stream, info_to_keep):

transform_row = _transform_variant_function_map.get((row_type, info_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):

def get_lgt(e, n_alleles, has_non_ref, row):
index = e.GT.unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (
hl.case()
.when(e.GT.is_haploid(), hl.or_missing(e.GT[0] < n_no_nonref, e.GT))
.when(index < triangle_without_nonref, e.GT)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus))
)

def make_entry_struct(e, alleles_len, has_non_ref, row):
handled_fields = dict()
handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'}

if 'GT' not in e:
raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = e.PGT
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(
has_non_ref,
hl.if_else(
alleles_len > 2,
hl.if_else(e.GT.is_haploid(), e.PL[:-1], e.PL[:-alleles_len]),
hl.missing(e.PL.dtype),
),
hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)),
)
handled_fields['RGQ'] = hl.if_else(
has_non_ref,
hl.if_else(
e.GT.is_haploid(),
e.PL[alleles_len - 1],
e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
),
hl.missing(e.PL.dtype.element_type),
)

handled_fields['gvcf_info'] = (
hl.case()
.when(
hl.is_missing(row.info.END),
parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref),
)
.or_missing()
)

pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

transform_row = hl.experimental.define_function(
lambda row: hl.rbind(
hl.len(row.alleles),
@@ -235,7 +177,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
locus=row.locus,
alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
**({'rsid': row.rsid} if 'rsid' in row else {}),
__entries=row.__entries.map(lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)),
__entries=row.__entries.map(lambda e: make_var_entry_struct(e, info_to_keep, alleles_len, has_non_ref, row)),
),
),
row_type,
@@ -259,33 +201,11 @@ def make_reference_matrix_table(mt: MatrixTable, entry_to_keep: Collection[str])
mt = mt.filter_rows(hl.is_defined(mt.info.END))
entry_key = tuple(sorted(entry_to_keep)) # hashable stable value

def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
handled_fields['LPL'] = e['PL'][:1]

reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names}
return (
hl.case()
.when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields))
.or_error('found END with non reference-genotype at' + hl.str(row.locus))
)

mt = localize(mt).key_by('locus')
transform_row = _transform_reference_fuction_map.get((mt.row.dtype, entry_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
transform_row = hl.experimental.define_function(
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))),
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_ref_entry_struct(e, entry_to_keep, row))),
mt.row.dtype,
)
_transform_reference_fuction_map[mt.row.dtype, entry_key] = transform_row
