Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vds/combiner] Deduplicate functions in combiner internals #14583

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 97 additions & 167 deletions hail/python/hail/vds/combiner/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,88 @@
_merge_function_map: Dict[Tuple[HailType, HailType], Function] = {}


def get_lgt(gt, n_alleles, has_non_ref, row):
index = gt.unphase().unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (
hl.case()
.when(gt.is_haploid(), hl.or_missing(gt[0] < n_no_nonref, gt))
.when(index < triangle_without_nonref, gt)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid call ' + hl.str(gt) + ' at site ' + hl.str(row.locus))
)


def make_var_entry_struct(e, info_to_keep, alleles_len, has_non_ref, row):
handled_fields = dict()
handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'}

if 'GT' not in e:
raise hl.utils.FatalError("the Hail VDS combiner expects input GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e.GT, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = e.PGT if e.PGT.dtype != hl.tcall else get_lgt(e.PGT, alleles_len, has_non_ref, row)
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(
has_non_ref,
hl.if_else(
alleles_len > 2,
hl.if_else(e.GT.is_haploid(), e.PL[:-1], e.PL[:-alleles_len]),
hl.missing(e.PL.dtype),
),
hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)),
)
handled_fields['RGQ'] = hl.if_else(
has_non_ref,
hl.if_else(
e.GT.is_haploid(),
e.PL[alleles_len - 1],
e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
),
hl.missing(e.PL.dtype.element_type),
)

handled_fields['gvcf_info'] = (
hl.case()
.when(
hl.is_missing(row.info.END),
parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref),
)
.or_missing()
)

pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)


def make_ref_entry_struct(e, entry_to_keep, row):
handled_fields = dict()
# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
handled_fields['LPL'] = e['PL'][:1]

reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names}
return (
hl.case()
.when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields))
.or_error('found END with non reference-genotype at' + hl.str(row.locus))
)


def make_variants_matrix_table(mt: MatrixTable, info_to_keep: Optional[Collection[str]] = None) -> MatrixTable:
if info_to_keep is None:
info_to_keep = []
Expand All @@ -32,66 +114,6 @@ def make_variants_matrix_table(mt: MatrixTable, info_to_keep: Optional[Collectio

transform_row = _transform_variant_function_map.get((mt.row.dtype, info_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):

def get_lgt(gt, n_alleles, has_non_ref, row):
index = gt.unphase().unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (
hl.case()
.when(gt.is_haploid(), hl.or_missing(gt[0] < n_no_nonref, gt))
.when(index < triangle_without_nonref, gt)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid call ' + hl.str(gt) + ' at site ' + hl.str(row.locus))
)

def make_entry_struct(e, alleles_len, has_non_ref, row):
handled_fields = dict()
handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'}

if 'GT' not in e:
raise hl.utils.FatalError("the Hail VDS combiner expects input GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e.GT, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = (
e.PGT if e.PGT.dtype != hl.tcall else get_lgt(e.PGT, alleles_len, has_non_ref, row)
)
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(
has_non_ref,
hl.if_else(
alleles_len > 2,
hl.if_else(e.GT.is_haploid(), e.PL[:-1], e.PL[:-alleles_len]),
hl.missing(e.PL.dtype),
),
hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)),
)
handled_fields['RGQ'] = hl.if_else(
has_non_ref,
hl.if_else(
e.GT.is_haploid(),
e.PL[alleles_len - 1],
e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
),
hl.missing(e.PL.dtype.element_type),
)

handled_fields['gvcf_info'] = (
hl.case()
.when(
hl.is_missing(row.info.END),
parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref),
)
.or_missing()
)

pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

transform_row = hl.experimental.define_function(
lambda row: hl.rbind(
hl.len(row.alleles),
Expand All @@ -100,7 +122,9 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
locus=row.locus,
alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
**({'rsid': row.rsid} if 'rsid' in row else {}),
__entries=row.__entries.map(lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)),
__entries=row.__entries.map(
lambda e: make_var_entry_struct(e, info_to_keep, alleles_len, has_non_ref, row)
),
),
),
mt.row.dtype,
Expand All @@ -120,39 +144,21 @@ def make_reference_stream(stream, entry_to_keep: Collection[str]):
stream = stream.filter(lambda elt: hl.is_defined(elt.info.END))
entry_key = tuple(sorted(entry_to_keep)) # hashable stable value

def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
handled_fields['LPL'] = e['PL'][:1]

reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names}
return (
hl.case()
.when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields))
.or_error('found END with non reference-genotype at' + hl.str(row.locus))
)

row_type = stream.dtype.element_type
transform_row = _transform_reference_fuction_map.get((row_type, entry_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
transform_row = hl.experimental.define_function(
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))),
lambda row: hl.struct(
locus=row.locus, __entries=row.__entries.map(lambda e: make_ref_entry_struct(e, entry_to_keep, row))
),
row_type,
)
_transform_reference_fuction_map[row_type, entry_key] = transform_row

return stream.map(
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row)))
lambda row: hl.struct(
locus=row.locus, __entries=row.__entries.map(lambda e: make_ref_entry_struct(e, entry_to_keep, row))
)
)


Expand All @@ -169,64 +175,6 @@ def make_variant_stream(stream, info_to_keep):

transform_row = _transform_variant_function_map.get((row_type, info_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):

def get_lgt(e, n_alleles, has_non_ref, row):
index = e.GT.unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (
hl.case()
.when(e.GT.is_haploid(), hl.or_missing(e.GT[0] < n_no_nonref, e.GT))
.when(index < triangle_without_nonref, e.GT)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus))
)

def make_entry_struct(e, alleles_len, has_non_ref, row):
handled_fields = dict()
handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'}

if 'GT' not in e:
raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = e.PGT
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(
has_non_ref,
hl.if_else(
alleles_len > 2,
hl.if_else(e.GT.is_haploid(), e.PL[:-1], e.PL[:-alleles_len]),
hl.missing(e.PL.dtype),
),
hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)),
)
handled_fields['RGQ'] = hl.if_else(
has_non_ref,
hl.if_else(
e.GT.is_haploid(),
e.PL[alleles_len - 1],
e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
),
hl.missing(e.PL.dtype.element_type),
)

handled_fields['gvcf_info'] = (
hl.case()
.when(
hl.is_missing(row.info.END),
parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref),
)
.or_missing()
)

pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

transform_row = hl.experimental.define_function(
lambda row: hl.rbind(
hl.len(row.alleles),
Expand All @@ -235,7 +183,9 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
locus=row.locus,
alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
**({'rsid': row.rsid} if 'rsid' in row else {}),
__entries=row.__entries.map(lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)),
__entries=row.__entries.map(
lambda e: make_var_entry_struct(e, info_to_keep, alleles_len, has_non_ref, row)
),
),
),
row_type,
Expand All @@ -259,33 +209,13 @@ def make_reference_matrix_table(mt: MatrixTable, entry_to_keep: Collection[str])
mt = mt.filter_rows(hl.is_defined(mt.info.END))
entry_key = tuple(sorted(entry_to_keep)) # hashable stable value

def make_entry_struct(e, row):
handled_fields = dict()
# we drop PL/PGT by default, but if `entry_to_keep` has them, we need to
# convert them to local versions for consistency.
handled_names = {'AD', 'GT', 'PGT', 'PL'}

if 'GT' in entry_to_keep:
handled_fields['LGT'] = e['GT']
if 'PGT' in entry_to_keep:
handled_fields['LPGT'] = e['PGT']
if 'AD' in entry_to_keep:
handled_fields['LAD'] = e['AD'][:1]
if 'PL' in entry_to_keep:
handled_fields['LPL'] = e['PL'][:1]

reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names}
return (
hl.case()
.when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields))
.or_error('found END with non reference-genotype at' + hl.str(row.locus))
)

mt = localize(mt).key_by('locus')
transform_row = _transform_reference_fuction_map.get((mt.row.dtype, entry_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
transform_row = hl.experimental.define_function(
lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))),
lambda row: hl.struct(
locus=row.locus, __entries=row.__entries.map(lambda e: make_ref_entry_struct(e, entry_to_keep, row))
),
mt.row.dtype,
)
_transform_reference_fuction_map[mt.row.dtype, entry_key] = transform_row
Expand Down