From 0d1d7373376c4441b63a0e2f19d4701442f21017 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Wed, 1 Mar 2023 09:22:54 -0500 Subject: [PATCH 01/11] [plotting] Add manhattan and ggplot methods built on plotly It's a snarly problem to add support for all the customization necessary to build these out of ggplot directly, but this lets us use ggplot in the tutorial notebooks without having to use a separate plotting lib. --- hail/python/hail/ggplot/__init__.py | 5 +- hail/python/hail/ggplot/premade.py | 261 ++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+), 1 deletion(-) create mode 100644 hail/python/hail/ggplot/premade.py diff --git a/hail/python/hail/ggplot/__init__.py b/hail/python/hail/ggplot/__init__.py index e6d2a195295..c543b93a7f6 100644 --- a/hail/python/hail/ggplot/__init__.py +++ b/hail/python/hail/ggplot/__init__.py @@ -10,6 +10,7 @@ scale_color_manual, scale_color_continuous, scale_fill_discrete, scale_fill_hue, scale_fill_identity, scale_fill_continuous,\ scale_fill_manual, scale_shape_manual, scale_shape_auto from .facets import vars, facet_wrap +from .premade import manhattan_plot, qq_plot __all__ = [ "aes", @@ -54,5 +55,7 @@ "scale_shape_manual", "scale_shape_auto", "facet_wrap", - "vars" + "vars", + "manhattan_plot", + "qq_plot", ] diff --git a/hail/python/hail/ggplot/premade.py b/hail/python/hail/ggplot/premade.py new file mode 100644 index 00000000000..74ce2e7e06b --- /dev/null +++ b/hail/python/hail/ggplot/premade.py @@ -0,0 +1,261 @@ +import math +from typing import Tuple, Dict, Union + +import plotly.express as px +import plotly.graph_objects as go + +import hail +from hail import Table +from hail.expr.expressions import expr_float64, expr_locus, expr_any, expr_numeric, Expression, NumericExpression +from hail.plot.plots import _collect_scatter_plot_data +from hail.typecheck import nullable, typecheck, numeric, dictof, oneof, sized_tupleof + + +@typecheck(pvals=expr_float64, locus=nullable(expr_locus()), title=nullable(str), + size=int, hover_fields=nullable(dictof(str, expr_any)), collect_all=bool, n_divisions=int, + significance_line=nullable(numeric)) +def manhattan_plot(pvals, locus=None, title=None, size=5, hover_fields=None, collect_all=False, n_divisions=500, + significance_line=5e-8): + """Create a Manhattan plot. (https://en.wikipedia.org/wiki/Manhattan_plot) + + Parameters + ---------- + pvals : :class:`.Float64Expression` + P-values to be plotted. + locus : :class:`.LocusExpression` + Locus values to be plotted. + title : str + Title of the plot. + size : int + Size of markers in screen space units. + hover_fields : Dict[str, :class:`.Expression`] + Dictionary of field names and values to be shown in the HoverTool of the plot. + collect_all : bool + Whether to collect all values or downsample before plotting. + n_divisions : int + Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints. + significance_line : float, optional + p-value at which to add a horizontal, dotted red line indicating + genome-wide significance. If ``None``, no line is added. + + Returns + ------- + :class:`bokeh.plotting.figure.Figure` + """ + if locus is None: + locus = pvals._indices.source.locus + + ref = locus.dtype.reference_genome + + if hover_fields is None: + hover_fields = {} + + hover_fields['locus'] = hail.str(locus) + hover_fields['contig_even'] = locus.contig_idx % 2 + + pvals = -hail.log10(pvals) + + source_pd = _collect_scatter_plot_data( + ('_global_locus', locus.global_position()), + ('_pval', pvals), + fields=hover_fields, + n_divisions=None if collect_all else n_divisions + ) + source_pd['p_value'] = [10 ** (-p) for p in source_pd['_pval']] + source_pd['_contig'] = [locus.split(":")[0] for locus in source_pd['locus']] + + observed_contigs = set(source_pd['_contig']) + observed_contigs = [contig for contig in ref.contigs.copy() if contig in observed_contigs] + + contig_ticks = [ref._contig_global_position(contig) + ref.contig_length(contig) // 2 for contig in observed_contigs] + + extra_hover = {k: True for k in hover_fields if k not in ('locus', 'contig_even')} + fig = px.scatter(source_pd, x='_global_locus', y='_pval', color='contig_even', + labels={ + '_global_locus': 'global position', + '_pval': '-log10 p-value', + 'p_value': 'p-value', + 'locus': 'locus' + }, + color_continuous_scale=["#00539C", "#EEA47F"], + template='ggplot2', + hover_name="locus", + hover_data={'contig_even': False, + '_global_locus': False, + '_pval': False, + 'p_value': ':.2e', + **extra_hover}) + + fig.update_layout( + xaxis_title="Genomic Coordinate", + yaxis_title="-log10 p-value", + title=title, + showlegend=False, + coloraxis_showscale=False, + xaxis=dict( + tickmode='array', + tickvals=contig_ticks, + ticktext=observed_contigs + ) + ) + + fig.update_traces(marker=dict(size=size), + selector=dict(mode='markers')) + + if significance_line is not None: + fig.add_hline(y=-math.log10(significance_line), line_dash='dash', line_color='red', opacity=1, line_width=2) + + return fig + + +@typecheck(pvals=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), + label=nullable(oneof(dictof(str, expr_any), expr_any)), title=nullable(str), + xlabel=nullable(str), ylabel=nullable(str), size=int, + hover_fields=nullable(dictof(str, expr_any)), + width=int, height=int, collect_all=bool, n_divisions=nullable(int), missing_label=str) +def qq_plot( + pvals: Union[NumericExpression, Tuple[str, NumericExpression]], + label: Union[Expression, Dict[str, Expression]] = None, + title: str = 'Q-Q plot', + xlabel: str = 'Expected -log10(p)', + ylabel: str = 'Observed -log10(p)', + size: int = 6, + hover_fields: Dict[str, Expression] = None, + width: int = 800, + height: int = 800, + collect_all: bool = False, + n_divisions: int = 500, + missing_label: str = 'NA' +): + """Create a Quantile-Quantile plot. (https://en.wikipedia.org/wiki/Q-Q_plot) + + ``pvals`` must be either: + - a :class:`.NumericExpression` + - a tuple (str, :class:`.NumericExpression`). If passed as a tuple the first element is used as the hover label. + + If no label or a single label is provided, then returns :class:`bokeh.plotting.figure.Figure` + Otherwise returns a :class:`bokeh.models.layouts.Column` containing: + - a :class:`bokeh.models.widgets.inputs.Select` dropdown selection widget for labels + - a :class:`bokeh.plotting.figure.Figure` containing the interactive qq plot + + Points will be colored by one of the labels defined in the ``label`` using the color scheme defined in + the corresponding entry of ``colors`` if provided (otherwise a default scheme is used). To specify your color + mapper, check `the bokeh documentation `__ + for CategoricalMapper for categorical labels, and for LinearColorMapper and LogColorMapper + for continuous labels. + For categorical labels, clicking on one of the items in the legend will hide/show all points with the corresponding label. + Note that using many different labelling schemes in the same plots, particularly if those labels contain many + different classes could slow down the plot interactions. + + Hovering on points will display their coordinates, labels and any additional fields specified in ``hover_fields``. + + Parameters + ---------- + pvals : :class:`.NumericExpression` or (str, :class:`.NumericExpression`) + List of x-values to be plotted. + label : :class:`.Expression` or Dict[str, :class:`.Expression`]] + Either a single expression (if a single label is desired), or a + dictionary of label name -> label value for x and y values. + Used to color each point w.r.t its label. + When multiple labels are given, a dropdown will be displayed with the different options. + Can be used with categorical or continuous expressions. + title : str + Title of the scatterplot. + xlabel : str + X-axis label. + ylabel : str + Y-axis label. + size : int + Size of markers in screen space units. + hover_fields : Dict[str, :class:`.Expression`] + Extra fields to be displayed when hovering over a point on the plot. + width: int + Plot width + height: int + Plot height + collect_all : bool + Whether to collect all values or downsample before plotting. + n_divisions : int + Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints. + missing_label: str + Label to use when a point is missing data for a categorical label + + Returns + ------- + :class:`bokeh.plotting.figure.Figure` if no label or a single label was given, otherwise :class:`bokeh.models.layouts.Column` + """ + hover_fields = {} if hover_fields is None else hover_fields + label = {} if label is None else {'label': label} if isinstance(label, Expression) else label + source = pvals._indices.source + if 'locus' in source.row: + hover_fields['__locus'] = source['locus'] + + if isinstance(source, Table): + ht = source.select(p_value=pvals, **hover_fields, **label) + else: + ht = source.select_rows(p_value=pvals, **hover_fields, **label).rows() + ht = ht.key_by().select('p_value', *hover_fields, *label).key_by('p_value').persist() + n = ht.count() + ht = ht.annotate( + observed_p=-hail.log10(ht['p_value']), + expected_p=-hail.log10((hail.scan.count() + 1) / n) + ).persist() + + if 'p_value' not in hover_fields: + hover_fields['p_value'] = ht.p_value + + df = _collect_scatter_plot_data( + ('expected p-value', ht.expected_p), + ('observed p-value', ht.observed_p), + fields={k: ht[k] for k in hover_fields}, + n_divisions=None if collect_all else n_divisions, + missing_label=missing_label + ) + + fig = px.scatter(df, x='expected p-value', y='observed p-value', + template='ggplot2', + hover_name="__locus", + hover_data={**{k: True for k in hover_fields}, + '__locus': False, + 'p_value': ':.2e'}) + + fig.update_traces(marker=dict(size=size, color='black'), + selector=dict(mode='markers')) + + from hail.methods.statgen import _lambda_gc_agg + lambda_gc, max_p = ht.aggregate( + (_lambda_gc_agg(ht['p_value']), hail.agg.max(hail.max(ht.observed_p, ht.expected_p)))) + fig.add_trace(go.Scatter(x=[0, max_p + 1], y=[0, max_p + 1], + mode='lines', + name='expected', + line=dict(color='red', width=3, dash='dot'))) + + label_color = 'red' if lambda_gc > 1.25 else 'orange' if lambda_gc > 1.1 else 'black' + + lgc_label = f'λ GC: {lambda_gc:.2f}' + + fig.update_layout( + autosize=False, + width=width, + height=height, + xaxis_title=xlabel, + xaxis_range=[0, max_p + 1], + yaxis_title=ylabel, + yaxis_range=[0, max_p + 1], + title=title, + showlegend=False, + annotations=[ + go.layout.Annotation( + font=dict(color=label_color, size=22), + text=lgc_label, + xanchor='left', + yanchor='bottom', + showarrow=False, + x=max_p * 0.8, + y=1 + ) + ], + hovermode="x unified" + ) + + return fig From c89030535c4e098fe9bf2ffaffeb2ec56064299b Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Thu, 2 Mar 2023 14:47:31 -0500 Subject: [PATCH 02/11] fix aes --- hail/python/hail/ggplot/aes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hail/python/hail/ggplot/aes.py b/hail/python/hail/ggplot/aes.py index 5497f28d4d2..01346ca2702 100644 --- a/hail/python/hail/ggplot/aes.py +++ b/hail/python/hail/ggplot/aes.py @@ -1,6 +1,7 @@ from collections.abc import Mapping from hail.expr import Expression from hail import literal +import hail class Aesthetic(Mapping): @@ -44,6 +45,8 @@ def aes(**kwargs): hail_field_properties = {} for k, v in kwargs.items(): + if isinstance(v, (tuple, dict)): + v = hail.str(v) if not isinstance(v, Expression): v = literal(v) hail_field_properties[k] = v From 98ce2fd45e73a00e380201bf6f538b71d2100260 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Fri, 3 Mar 2023 10:01:16 -0500 Subject: [PATCH 03/11] add arg to ibd --- .../python/hail/methods/relatedness/identity_by_descent.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hail/python/hail/methods/relatedness/identity_by_descent.py b/hail/python/hail/methods/relatedness/identity_by_descent.py index ecce456eaca..a5b60e68cd2 100644 --- a/hail/python/hail/methods/relatedness/identity_by_descent.py +++ b/hail/python/hail/methods/relatedness/identity_by_descent.py @@ -15,8 +15,9 @@ maf=nullable(expr_float64), bounded=bool, min=nullable(numeric), - max=nullable(numeric)) -def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> Table: + max=nullable(numeric), + _use_python_impl=bool) +def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None, _use_python_impl=False) -> Table: """Compute matrix of identity-by-descent estimates. .. include:: ../_templates/req_tstring.rst @@ -104,7 +105,7 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> dataset = dataset.select_cols().select_globals().select_entries('GT') dataset = require_biallelic(dataset, 'ibd') - if isinstance(Env.backend(), SparkBackend): + if isinstance(Env.backend(), SparkBackend) and not _use_python_impl: return Table(ir.MatrixToTableApply(dataset._mir, { 'name': 'IBD', 'mafFieldName': '__maf' if maf is not None else None, From 42b1e63ccbdccccfd63928afc2610c43087943ab Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Tue, 7 Mar 2023 15:38:30 -0500 Subject: [PATCH 04/11] Additional changes used for workshop --- hail/python/hail/expr/functions.py | 2 +- .../relatedness/identity_by_descent.py | 21 ++- .../methods/relatedness/mating_simulation.py | 152 ++++++++++++------ .../test/hail/methods/test_simulation.py | 16 +- 4 files changed, 130 insertions(+), 61 deletions(-) diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index 950bb394387..97fad620728 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -3905,7 +3905,7 @@ def group_by(f: Callable, collection) -> DictExpression: @typecheck(f=func_spec(2, expr_any), zero=expr_any, collection=expr_oneof(expr_set(), expr_array())) -def fold(f: Callable, zero, collection) -> Expression: +def fold(f: Callable, zero, collection): """Reduces a collection with the given function `f`, provided the initial value `zero`. Examples diff --git a/hail/python/hail/methods/relatedness/identity_by_descent.py b/hail/python/hail/methods/relatedness/identity_by_descent.py index a5b60e68cd2..b97c3146763 100644 --- a/hail/python/hail/methods/relatedness/identity_by_descent.py +++ b/hail/python/hail/methods/relatedness/identity_by_descent.py @@ -90,8 +90,6 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None, _us :class:`.Table` """ - require_col_key_str(dataset, 'identity_by_descent') - if not isinstance(dataset.GT, hl.CallExpression): raise Exception('GT field must be of type Call') @@ -119,9 +117,14 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None, _us if not 0 <= min <= max <= 1: raise Exception(f"invalid pi hat filters {min} {max}") - sample_ids = dataset.s.collect() - if len(sample_ids) != len(set(sample_ids)): - raise Exception('duplicate sample ids found') + col_key_field = list(dataset.col_key)[0] + ds_unkey = dataset.key_cols_by() + _ids, dups = ds_unkey.aggregate_cols((hl.agg.collect(ds_unkey[col_key_field]), + hl.array(hl.agg.counter(ds_unkey[col_key_field])) + .filter(lambda x: x[1] > 1) + .map(lambda x: x[0]))) + if len(dups) > 0: + raise Exception(f'identity_by_descent: {len(dups)} duplicate sample ids found: {list(dups.keys())[:10]}') dataset = dataset.annotate_entries( n_alt_alleles=hl.or_else(dataset.GT.n_alt_alleles(), 0), @@ -158,6 +161,8 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None, _us + p * (q ** 2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2))), _e22=(T / 2) ) + dataset = dataset.filter_rows( + ~hl.any(*(hl.is_nan(dataset[x]) for x in ('_e00', '_e10', '_e20', '_e11', '_e21', '_e22')))) dataset = dataset.checkpoint(hl.utils.new_temp_file()) @@ -222,10 +227,10 @@ def bound_result(_ibd): result = result.annotate(ibd=result.ibd.annotate(PI_HAT=result.ibd.Z1 / 2 + result.ibd.Z2)) result = result.filter((result.i < result.j) & (min <= result.ibd.PI_HAT) & (result.ibd.PI_HAT <= max)) - samples = hl.literal(dataset.s.collect()) + ids = hl.literal(_ids) result = result.key_by( - i=samples[hl.int32(result.i)], - j=samples[hl.int32(result.j)] + i=ids[hl.int32(result.i)], + j=ids[hl.int32(result.j)] ) return result.persist() diff --git a/hail/python/hail/methods/relatedness/mating_simulation.py b/hail/python/hail/methods/relatedness/mating_simulation.py index 4ade8f7d69c..8f107499efc 100644 --- a/hail/python/hail/methods/relatedness/mating_simulation.py +++ b/hail/python/hail/methods/relatedness/mating_simulation.py @@ -1,70 +1,130 @@ import hail as hl -from hail.typecheck import typecheck, numeric +from hail.typecheck import typecheck, numeric, nullable +import random +from hail.utils.java import info -@typecheck(mt=hl.MatrixTable, n_rounds=int, generation_size_multiplier=numeric, keep_founders=bool) -def simulate_random_mating(mt, n_rounds=1, generation_size_multiplier=1.0, keep_founders=True): +@typecheck(mt=hl.MatrixTable, + n_rounds=int, + pairs_per_generation_multiplier=numeric, + children_per_pair=int, + seed=nullable(int)) +def simulate_random_mating(mt, + n_rounds=1, + pairs_per_generation_multiplier=0.5, + children_per_pair=2, + seed=None): """Simulate random diploid mating to produce new individuals. + .. include:: _templates/experimental.rst + + Exmaples + -------- + + >>> dataset_sim = hl.simulate_random_mating(dataset, n_rounds=2, pairs_per_generation_multiplier=0.5) + Parameters ---------- mt n_rounds : :obj:`int` Number of rounds of mating. - generation_size_multiplier : :obj:`float` - Ratio of number of offspring to current population for each round of mating. - keep_founders :obj:`bool` - If true, keep all founders and intermediate generations in the final sample list. If - false, keep only offspring in the last generation. - + pairs_per_generation_multiplier : :obj:`float` + Ratio of number of mating pairs to current population size for each round of mating. + children_per_pair : :obj:`int` + Number of children per mating pair. Returns ------- :class:`.MatrixTable` """ - if generation_size_multiplier <= 0: + if pairs_per_generation_multiplier <= 0: raise ValueError( - f"simulate_random_mating: 'generation_size_multiplier' must be greater than zero: got {generation_size_multiplier}") + f"simulate_random_mating: 'generation_size_multiplier' must be greater than zero: got {pairs_per_generation_multiplier}") if n_rounds < 1: raise ValueError(f"simulate_random_mating: 'n_rounds' must be positive: got {n_rounds}") - ck = list(mt.col_key)[0] - mt = mt.select_entries('GT') - ht = mt.localize_entries('__entries', '__cols') - ht = ht.annotate_globals( - generation_0=hl.range(hl.len(ht.__cols)).map(lambda i: hl.struct(s=hl.str('generation_0_idx_') + hl.str(i), - original=hl.str(ht.__cols[i][ck]), - mother=hl.missing('int32'), - father=hl.missing('int32')))) - - def make_new_generation(prev_generation_tup, idx): - prev_size = prev_generation_tup[1] - n_new = hl.int32(hl.floor(prev_size * generation_size_multiplier)) - new_generation = hl.range(n_new).map( - lambda i: hl.struct(s=hl.str('generation_') + hl.str(idx + 1) + hl.str('_idx_') + hl.str(i), - original=hl.missing('str'), - mother=hl.rand_int32(0, prev_size), - father=hl.rand_int32(0, prev_size))) - return (new_generation, (prev_size + n_new) if keep_founders else n_new) - - ht = ht.annotate_globals(generations=hl.range(n_rounds).scan(lambda prev, idx: make_new_generation(prev, idx), - (ht.generation_0, hl.len(ht.generation_0)))) - - def simulate_mating_calls(prev_generation_calls, new_generation): - new_samples = new_generation.map(lambda samp: hl.call(prev_generation_calls[samp.mother][hl.rand_int32(0, 2)], - prev_generation_calls[samp.father][hl.rand_int32(0, 2)])) - if keep_founders: - return prev_generation_calls.extend(new_samples) - else: - return new_samples + ns = mt.count_cols() + + # dict of true nonzero relatedness. indeed by tuples of (id1, id2) where a pair is stored with the larger (later) id first. + from collections import defaultdict + relatedness = defaultdict(dict) + + def get_rel(s1, s2): + if s1 > s2: + if s1 in relatedness: + return relatedness[s1].get(s2, 0.0) + elif s2 in relatedness: + return relatedness[s2].get(s1, 0.0) + return 0.0 + + samples = [(i, f'founder_{i}', None, None) for i in range(ns)] + info(f'simulate_random_mating: {len(samples)} founders, {n_rounds} rounds of mating to do') + last_generation_start_idx = 0 + indices = [] + + if seed is not None: + random.seed(seed) + for generation in range(n_rounds): + last_generation_end = len(samples) + mating_generation_size = last_generation_end - last_generation_start_idx + + new_pairs = int(mating_generation_size * pairs_per_generation_multiplier) + + curr_sample_idx = len(samples) + for pair in range(new_pairs): + mother = int(random.uniform(last_generation_start_idx, last_generation_end)) + father = int(last_generation_start_idx + ( + mother + random.uniform(1, mating_generation_size)) % mating_generation_size) + + mother_rel = relatedness[mother] + father_rel = relatedness[father] + + merged_parent_rel = {} + for k, v in mother_rel.items(): + merged_parent_rel[k] = .5 * (v + father_rel.get(k, 0.0)) + for k, v in father_rel.items(): + if k not in mother_rel: + merged_parent_rel[k] = .5 * v + + child_rel_value = 0.25 + get_rel(mother, father) / 2 + first_child = curr_sample_idx + for child in range(children_per_pair): + samples.append( + (curr_sample_idx, f'generation_{generation + 1}_pair_{pair}_child_{child}', mother, father)) + relatedness[curr_sample_idx] = merged_parent_rel.copy() + relatedness[curr_sample_idx][mother] = child_rel_value + relatedness[curr_sample_idx][father] = child_rel_value + + if child > 0: + relatedness[curr_sample_idx][first_child] = child_rel_value + + curr_sample_idx += 1 + info( + f'simulate_random_mating: generation {generation + 1}: ' + f'{curr_sample_idx - last_generation_end} new samples, ' + f'for a total of {len(samples)}') + + indices.append((last_generation_end, curr_sample_idx)) + last_generation_start_idx = last_generation_end + + ht = ht.annotate_globals(__samples=hl.literal(samples, dtype='tarray') + .map(lambda t: hl.struct(sample_idx=t[0], s=t[1], mother=t[2], father=t[3])), + __indices=indices, + relatedness=relatedness) + + def simulate_mating_calls(prev_generation_calls, samples, indices): + new_samples = hl.range(indices[0], indices[1]) \ + .map(lambda i: samples[i]) \ + .map(lambda samp: hl.call(prev_generation_calls[samp.mother][hl.rand_int32(0, 2)], + prev_generation_calls[samp.father][hl.rand_int32(0, 2)])) + return prev_generation_calls.extend(new_samples) + samples = ht.__samples ht = ht.annotate(__new_entries=hl.fold( - lambda prev_calls, generation_metadata: simulate_mating_calls(prev_calls, generation_metadata[0]), + lambda prev_calls, indices: simulate_mating_calls(prev_calls, samples, indices), ht.__entries.GT, - ht.generations[1:]).map(lambda gt: hl.struct(GT=gt))) - ht = ht.annotate_globals( - __new_cols=ht.generations.flatmap(lambda x: x[0]) if keep_founders else ht.generations[-1][0]) - ht = ht.drop('__entries', '__cols', 'generation_0', 'generations') - return ht._unlocalize_entries('__new_entries', '__new_cols', list('s')) + ht.__indices).map(lambda gt: hl.struct(GT=gt))) + ht = ht.drop('__entries', '__cols', '__indices') + return ht._unlocalize_entries('__new_entries', '__samples', list('s')) diff --git a/hail/python/test/hail/methods/test_simulation.py b/hail/python/test/hail/methods/test_simulation.py index c058eeecf9e..ade336e7a3e 100644 --- a/hail/python/test/hail/methods/test_simulation.py +++ b/hail/python/test/hail/methods/test_simulation.py @@ -8,10 +8,14 @@ def test_mating_simulation(): n_samples = mt.count_cols() - assert hl.simulate_random_mating(mt, n_rounds=1, generation_size_multiplier=2, keep_founders=False).count_cols() == n_samples * 2 - assert hl.simulate_random_mating(mt, n_rounds=4, generation_size_multiplier=2, keep_founders=False).count_cols() == n_samples * 16 - assert hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=1, keep_founders=False).count_cols() == n_samples - assert hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=2, keep_founders=True).count_cols() == n_samples * 9 + assert hl.simulate_random_mating(mt, n_rounds=1, pairs_per_generation_multiplier=0.5, + children_per_pair=2).count_cols() == n_samples * 2 + assert hl.simulate_random_mating(mt, n_rounds=4, pairs_per_generation_multiplier=0.5, + children_per_pair=2).count_cols() == n_samples * 5 + assert hl.simulate_random_mating(mt, n_rounds=3, pairs_per_generation_multiplier=1, + children_per_pair=2).count_cols() == n_samples + n_samples * 2 + n_samples * 4 + n_samples * 8 + assert hl.simulate_random_mating(mt, n_rounds=2, pairs_per_generation_multiplier=0.5, + children_per_pair=1).count_cols() == n_samples * 1.75 - - hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=0.5, keep_founders=True)._force_count_rows() \ No newline at end of file + hl.simulate_random_mating(mt, n_rounds=2, pairs_per_generation_multiplier=0.5, + children_per_pair=2)._force_count_rows() From 5fb4408ad6fa2a8e53db37971b09919c067d830e Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Tue, 7 Mar 2023 15:50:46 -0500 Subject: [PATCH 05/11] Additional changes used for workshop --- hail/python/hail/methods/relatedness/mating_simulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hail/python/hail/methods/relatedness/mating_simulation.py b/hail/python/hail/methods/relatedness/mating_simulation.py index 8f107499efc..9554a2204eb 100644 --- a/hail/python/hail/methods/relatedness/mating_simulation.py +++ b/hail/python/hail/methods/relatedness/mating_simulation.py @@ -60,7 +60,7 @@ def get_rel(s1, s2): return 0.0 samples = [(i, f'founder_{i}', None, None) for i in range(ns)] - info(f'simulate_random_mating: {len(samples)} founders, {n_rounds} rounds of mating to do') + print(f'simulate_random_mating: {len(samples)} founders, {n_rounds} rounds of mating to do') last_generation_start_idx = 0 indices = [] @@ -101,7 +101,7 @@ def get_rel(s1, s2): relatedness[curr_sample_idx][first_child] = child_rel_value curr_sample_idx += 1 - info( + print( f'simulate_random_mating: generation {generation + 1}: ' f'{curr_sample_idx - last_generation_end} new samples, ' f'for a total of {len(samples)}') From 1ab4690dd7e0bd37f54eaca1a9eadfc50c196c6b Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Wed, 8 Mar 2023 08:00:07 -0500 Subject: [PATCH 06/11] fix relatedness estimation --- .../methods/relatedness/mating_simulation.py | 70 ++++++++++++++----- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/hail/python/hail/methods/relatedness/mating_simulation.py b/hail/python/hail/methods/relatedness/mating_simulation.py index 9554a2204eb..f9196da15b2 100644 --- a/hail/python/hail/methods/relatedness/mating_simulation.py +++ b/hail/python/hail/methods/relatedness/mating_simulation.py @@ -1,7 +1,6 @@ import hail as hl from hail.typecheck import typecheck, numeric, nullable import random -from hail.utils.java import info @typecheck(mt=hl.MatrixTable, @@ -52,18 +51,21 @@ def simulate_random_mating(mt, relatedness = defaultdict(dict) def get_rel(s1, s2): - if s1 > s2: - if s1 in relatedness: - return relatedness[s1].get(s2, 0.0) - elif s2 in relatedness: - return relatedness[s2].get(s1, 0.0) - return 0.0 + smaller, larger = (s1, s2) if s1 < s2 else s2, s1 + return relatedness[larger].get(smaller, 0.0) + + def set_rel(from_s, to_s, value, fwd, generation_start): + relatedness[from_s][to_s] = value + if to_s >= generation_start: + fwd[to_s][from_s] = value samples = [(i, f'founder_{i}', None, None) for i in range(ns)] print(f'simulate_random_mating: {len(samples)} founders, {n_rounds} rounds of mating to do') last_generation_start_idx = 0 indices = [] + last_gen_fwd = defaultdict(dict) + if seed is not None: random.seed(seed) for generation in range(n_rounds): @@ -72,33 +74,62 @@ def get_rel(s1, s2): new_pairs = int(mating_generation_size * pairs_per_generation_multiplier) + curr_gen_fwd = defaultdict(dict) + curr_sample_idx = len(samples) + parent_to_child = defaultdict(set) for pair in range(new_pairs): mother = int(random.uniform(last_generation_start_idx, last_generation_end)) father = int(last_generation_start_idx + ( mother + random.uniform(1, mating_generation_size)) % mating_generation_size) - mother_rel = relatedness[mother] - father_rel = relatedness[father] + mother_rel = {**relatedness[mother], **last_gen_fwd[mother]} + father_rel = {**relatedness[father], **last_gen_fwd[father]} merged_parent_rel = {} + for k, v in mother_rel.items(): - merged_parent_rel[k] = .5 * (v + father_rel.get(k, 0.0)) + merged_parent_rel[k] = .5 * v for k, v in father_rel.items(): - if k not in mother_rel: + if k in mother_rel: + merged_parent_rel[k] += .5 * v + else: merged_parent_rel[k] = .5 * v - child_rel_value = 0.25 + get_rel(mother, father) / 2 - first_child = curr_sample_idx + if mother in merged_parent_rel: + assert father in merged_parent_rel + merged_parent_rel[mother] += 0.25 + merged_parent_rel[father] += 0.25 + else: + merged_parent_rel[mother] = 0.25 + merged_parent_rel[father] = 0.25 + sibling_rel = merged_parent_rel[mother] # here mother/father should be the same, since + for child in range(children_per_pair): samples.append( (curr_sample_idx, f'generation_{generation + 1}_pair_{pair}_child_{child}', mother, father)) - relatedness[curr_sample_idx] = merged_parent_rel.copy() - relatedness[curr_sample_idx][mother] = child_rel_value - relatedness[curr_sample_idx][father] = child_rel_value - - if child > 0: - relatedness[curr_sample_idx][first_child] = child_rel_value + for k, v in merged_parent_rel.items(): + set_rel(curr_sample_idx, k, v, curr_gen_fwd, last_generation_start_idx) + + mother_sibs = parent_to_child[mother] + father_sibs = parent_to_child[father] + for sib in mother_sibs: + if sib in father_sibs: + rel = sibling_rel + else: + _, _, sib_mom, sib_dad = samples[sib] + other_parent = sib_mom if sib_mom != mother else sib_dad + rel = .125 + get_rel(mother, other_parent) / 2 + set_rel(curr_sample_idx, sib, rel, curr_gen_fwd, last_generation_start_idx) + for sib in father_sibs: + if sib not in mother_sibs: + _, _, sib_mom, sib_dad = samples[sib] + other_parent = sib_mom if sib_mom != father else sib_dad + set_rel(curr_sample_idx, sib, 0.125 + get_rel(father, other_parent) / 2, curr_gen_fwd, + last_generation_start_idx) + + mother_sibs.add(curr_sample_idx) + father_sibs.add(curr_sample_idx) curr_sample_idx += 1 print( @@ -108,6 +139,7 @@ def get_rel(s1, s2): indices.append((last_generation_end, curr_sample_idx)) last_generation_start_idx = last_generation_end + last_gen_fwd = curr_gen_fwd ht = ht.annotate_globals(__samples=hl.literal(samples, dtype='tarray') .map(lambda t: hl.struct(sample_idx=t[0], s=t[1], mother=t[2], father=t[3])), From b123465fc0bf21651bcfdfa8e634c6ee5543b7f0 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Thu, 9 Mar 2023 07:22:33 -0700 Subject: [PATCH 07/11] add test --- .../methods/relatedness/mating_simulation.py | 21 ++++++----- .../test/hail/methods/test_simulation.py | 36 +++++++++++++++++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/hail/python/hail/methods/relatedness/mating_simulation.py b/hail/python/hail/methods/relatedness/mating_simulation.py index f9196da15b2..6ffd4e7b940 100644 --- a/hail/python/hail/methods/relatedness/mating_simulation.py +++ b/hail/python/hail/methods/relatedness/mating_simulation.py @@ -103,7 +103,10 @@ def set_rel(from_s, to_s, value, fwd, generation_start): else: merged_parent_rel[mother] = 0.25 merged_parent_rel[father] = 0.25 - sibling_rel = merged_parent_rel[mother] # here mother/father should be the same, since + + # here mother/father be the same, since the edge from father to mother + # and the edge from mother to father must be identical. + sibling_rel = merged_parent_rel[mother] for child in range(children_per_pair): samples.append( @@ -111,25 +114,25 @@ def set_rel(from_s, to_s, value, fwd, generation_start): for k, v in merged_parent_rel.items(): set_rel(curr_sample_idx, k, v, curr_gen_fwd, last_generation_start_idx) - mother_sibs = parent_to_child[mother] - father_sibs = parent_to_child[father] - for sib in mother_sibs: - if sib in father_sibs: + mother_other_kids = parent_to_child[mother] + father_other_kids = parent_to_child[father] + for sib in mother_other_kids: + if sib in father_other_kids: rel = sibling_rel else: _, _, sib_mom, sib_dad = samples[sib] other_parent = sib_mom if sib_mom != mother else sib_dad rel = .125 + get_rel(mother, other_parent) / 2 set_rel(curr_sample_idx, sib, rel, curr_gen_fwd, last_generation_start_idx) - for sib in father_sibs: - if sib not in mother_sibs: + for sib in father_other_kids: + if sib not in mother_other_kids: _, _, sib_mom, sib_dad = samples[sib] other_parent = sib_mom if sib_mom != father else sib_dad set_rel(curr_sample_idx, sib, 0.125 + get_rel(father, other_parent) / 2, curr_gen_fwd, last_generation_start_idx) - mother_sibs.add(curr_sample_idx) - father_sibs.add(curr_sample_idx) + mother_other_kids.add(curr_sample_idx) + father_other_kids.add(curr_sample_idx) curr_sample_idx += 1 print( diff --git a/hail/python/test/hail/methods/test_simulation.py b/hail/python/test/hail/methods/test_simulation.py index ade336e7a3e..0dc995b0246 100644 --- a/hail/python/test/hail/methods/test_simulation.py +++ b/hail/python/test/hail/methods/test_simulation.py @@ -19,3 +19,39 @@ def test_mating_simulation(): hl.simulate_random_mating(mt, n_rounds=2, pairs_per_generation_multiplier=0.5, children_per_pair=2)._force_count_rows() + + +def test_relatedness(): + mt = hl.balding_nichols_model(n_populations=2, n_samples=2, n_variants=2) + # this configuration forces 2 unrelated mating to produce 2 samples, who mate to produce 2 more + rel = hl.eval(hl.simulate_random_mating(mt, + n_rounds=2, + pairs_per_generation_multiplier=0.5, + children_per_pair=2).index_globals().relatedness) + + assert rel == { + 0: {}, + 1: {}, + 2: { + 0: 0.25, + 1: 0.25 + }, + 3: { + 0: 0.25, + 1: 0.25, + 2: 0.25, + }, + 4: { + 0: 0.25, + 1: 0.25, + 2: 0.375, + 3: 0.375, + }, + 5: { + 0: 0.25, + 1: 0.25, + 2: 0.375, + 3: 0.375, + 4: 0.375, + }, + } From b3dfdc944d95bee44b898e130b9063efc5cdb0a9 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Thu, 9 Mar 2023 15:11:43 -0700 Subject: [PATCH 08/11] logging improvements --- hail/python/hail/backend/spark_backend.py | 4 ++-- hail/python/hail/matrixtable.py | 6 ------ hail/python/hail/methods/relatedness/pc_relate.py | 11 ++++++++++- hail/python/hail/plot/plots.py | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index a7e4798b1a6..8275008bf77 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -18,7 +18,7 @@ from hail.matrixtable import MatrixTable from .py4j_backend import Py4JBackend, handle_java_exception -from ..hail_logging import Logger +from ..hail_logging import Logger, PythonOnlyLogger if pyspark.__version__ < '3' and sys.version_info > (3, 8): raise EnvironmentError('Hail with spark {} requires Python 3.7, found {}.{}'.format( @@ -241,7 +241,7 @@ def stop(self): @property def logger(self): if self._logger is None: - self._logger = Log4jLogger(self._utils_package_object) + self._logger = PythonOnlyLogger() return self._logger @property diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index f9369a91d42..391134f7b80 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -2955,12 +2955,6 @@ def entries(self) -> Table: :class:`.Table` Table with all non-global fields from the matrix, with **one row per entry of the matrix**. """ - if Env.hc()._warn_entries_order and len(self.col_key) > 0: - warning("entries(): Resulting entries table is sorted by '(row_key, col_key)'." - "\n To preserve row-major matrix table order, " - "first unkey columns with 'key_cols_by()'") - Env.hc()._warn_entries_order = False - return Table(ir.MatrixEntriesTable(self._mir)) def index_globals(self) -> Expression: diff --git a/hail/python/hail/methods/relatedness/pc_relate.py b/hail/python/hail/methods/relatedness/pc_relate.py index 37be5b04f6e..963f9e869e5 100644 --- a/hail/python/hail/methods/relatedness/pc_relate.py +++ b/hail/python/hail/methods/relatedness/pc_relate.py @@ -12,7 +12,7 @@ from hail.table import Table from hail.typecheck import enumeration, nullable, numeric, typecheck from hail.utils import new_temp_file -from hail.utils.java import Env +from hail.utils.java import Env, info from ..pca import _hwe_normalized_blanczos, hwe_normalized_pca @@ -313,6 +313,7 @@ def pc_relate(call_expr: CallExpression, mt = matrix_table_source('pc_relate/call_expr', call_expr) if k and scores_expr is None: + info(f"pc_relate: computing {k} principal components...") _, scores, _ = hwe_normalized_pca(call_expr, k, compute_loadings=False) scores_expr = scores[mt.col_key].scores elif not k and scores_expr is not None: @@ -325,6 +326,7 @@ def pc_relate(call_expr: CallExpression, scores_table = mt.select_cols(__scores=scores_expr) \ .key_cols_by().select_cols('__scores').cols() + info(f"pc_relate: computing score missingness...") n_missing = scores_table.aggregate(agg.count_where(hl.is_missing(scores_table.__scores))) if n_missing > 0: raise ValueError(f'Found {n_missing} columns with missing scores array.') @@ -336,11 +338,15 @@ def pc_relate(call_expr: CallExpression, if not block_size: block_size = BlockMatrix.default_block_size() + info(f"pc_relate: imputing missing data and writing as block matrix...") + g = BlockMatrix.from_entry_expr(mean_imputed_gt, block_size=block_size) + info(f"pc_relate: collecting scores table...") pcs = scores_table.collect(_localize=False).map(lambda x: x.__scores) + info(f"pc_relate: Running PC-Relate model...") ht = Table(ir.BlockMatrixToTableApply(g._bmir, pcs._ir, { 'name': 'PCRelate', 'maf': min_individual_maf, @@ -358,7 +364,10 @@ def pc_relate(call_expr: CallExpression, if not include_self_kinship: ht = ht.filter(ht.i == ht.j, keep=False) + info(f"pc_relate: collecting column keys...") + col_keys = hl.literal(mt.select_cols().key_cols_by().cols().collect(), dtype=tarray(mt.col_key.dtype)) + info(f"pc_relate: keying result table and caching") return ht.key_by(i=col_keys[ht.i], j=col_keys[ht.j]).persist() diff --git a/hail/python/hail/plot/plots.py b/hail/python/hail/plot/plots.py index f1395aa3220..aecae7a89ae 100644 --- a/hail/python/hail/plot/plots.py +++ b/hail/python/hail/plot/plots.py @@ -771,7 +771,7 @@ def _collect_scatter_plot_data( expressions = {k: hail.str(v) if not isinstance(v, StringExpression) else v for k, v in expressions.items()} agg_f = x[1]._aggregation_method() - res = agg_f(hail.agg.downsample(x[1], y[1], label=list(expressions.values()) if expressions else None, n_divisions=n_divisions)) + res = agg_f(hail.agg.downsample(x[1], y[1], label=hail.array(list(expressions.values())) if expressions else None, n_divisions=n_divisions)) source_pd = pd.DataFrame([ dict( **{x[0]: point[0], y[0]: point[1]}, From 8ff98bacc0b392d8a17fe0e1b1c74c148a1ea740 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Thu, 9 Mar 2023 22:21:02 -0700 Subject: [PATCH 09/11] add max to show --- hail/python/hail/matrixtable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 391134f7b80..0040e0d554d 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -2812,7 +2812,7 @@ def estimate_size(struct_expression): if include_row_fields: characters -= estimate_size(self.row_value) characters = max(characters, 0) - n_cols = characters // (estimate_size(self.entry) + 4) # 4 for the column index + n_cols = max(characters // (estimate_size(self.entry) + 4), 3) # 4 for the column index actual_n_cols = self.count_cols() displayed_n_cols = min(actual_n_cols, n_cols) From 37f305e255362a1c972299f36f1d32c8acf0d250 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Fri, 10 Mar 2023 06:43:28 -0700 Subject: [PATCH 10/11] col-warning --- hail/python/hail/matrixtable.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 0040e0d554d..e748c4fbd96 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -2894,13 +2894,6 @@ def cols(self) -> Table: :class:`.Table` Table with all column fields from the matrix, with one row per column of the matrix. """ - - if len(self.col_key) != 0 and Env.hc()._warn_cols_order: - warning("cols(): Resulting column table is sorted by 'col_key'." - "\n To preserve matrix table column order, " - "first unkey columns with 'key_cols_by()'") - Env.hc()._warn_cols_order = False - return Table(ir.MatrixColsTable(self._mir)) def entries(self) -> Table: From e406e9c68d0810ffabae4d41ffb91da2d70669df Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Fri, 10 Mar 2023 07:00:25 -0700 Subject: [PATCH 11/11] fix log --- hail/python/hail/methods/relatedness/pc_relate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/python/hail/methods/relatedness/pc_relate.py b/hail/python/hail/methods/relatedness/pc_relate.py index 963f9e869e5..a9b27eb6464 100644 --- a/hail/python/hail/methods/relatedness/pc_relate.py +++ b/hail/python/hail/methods/relatedness/pc_relate.py @@ -367,7 +367,7 @@ def pc_relate(call_expr: CallExpression, info(f"pc_relate: collecting column keys...") col_keys = hl.literal(mt.select_cols().key_cols_by().cols().collect(), dtype=tarray(mt.col_key.dtype)) - info(f"pc_relate: keying result table and caching") + info(f"pc_relate: re-keying result table...") return ht.key_by(i=col_keys[ht.i], j=col_keys[ht.j]).persist()