Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[workshop] Add necessary infrastructure for IBG 2023 workshop #12769

Closed
wants to merge 11 commits into from
4 changes: 2 additions & 2 deletions hail/python/hail/backend/spark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hail.matrixtable import MatrixTable

from .py4j_backend import Py4JBackend, handle_java_exception
from ..hail_logging import Logger
from ..hail_logging import Logger, PythonOnlyLogger

if pyspark.__version__ < '3' and sys.version_info > (3, 8):
raise EnvironmentError('Hail with spark {} requires Python 3.7, found {}.{}'.format(
Expand Down Expand Up @@ -241,7 +241,7 @@ def stop(self):
@property
def logger(self):
if self._logger is None:
self._logger = Log4jLogger(self._utils_package_object)
self._logger = PythonOnlyLogger()
return self._logger

@property
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3905,7 +3905,7 @@ def group_by(f: Callable, collection) -> DictExpression:
@typecheck(f=func_spec(2, expr_any),
zero=expr_any,
collection=expr_oneof(expr_set(), expr_array()))
def fold(f: Callable, zero, collection) -> Expression:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by fix. Hinting the superclass of the returned value causes lint issues (in my IDE at least), leaving it out and using duck typing doesn't. Can remove if you want, but I've been stripping these out when I see them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Pyright complains about our use of Callable, but I don't understand the error message [1]. With or without this change, pyright says the type of

fold(lambda _: 3, literal(0), literal([1]))

is Expression. What type do you get?

Hmm, doesn't seem like there's integration for Pyright with IntelliJ unfortunately. The Python community seems to unifying on it since it's remarkably good.

[1]: Illegal type annotation: variable not allowed unless it is a type alias (lsp)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohhhh, I bet this is because your IDE thinks the return type is the return type of the function after modification by @typecheck

def fold(f: Callable, zero, collection):
"""Reduces a collection with the given function `f`, provided the initial value `zero`.

Examples
Expand Down
5 changes: 4 additions & 1 deletion hail/python/hail/ggplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
scale_color_manual, scale_color_continuous, scale_fill_discrete, scale_fill_hue, scale_fill_identity, scale_fill_continuous,\
scale_fill_manual, scale_shape_manual, scale_shape_auto
from .facets import vars, facet_wrap
from .premade import manhattan_plot, qq_plot

__all__ = [
"aes",
Expand Down Expand Up @@ -54,5 +55,7 @@
"scale_shape_manual",
"scale_shape_auto",
"facet_wrap",
"vars"
"vars",
"manhattan_plot",
"qq_plot",
]
3 changes: 3 additions & 0 deletions hail/python/hail/ggplot/aes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Mapping
from hail.expr import Expression
from hail import literal
import hail


class Aesthetic(Mapping):
Expand Down Expand Up @@ -44,6 +45,8 @@ def aes(**kwargs):
hail_field_properties = {}

for k, v in kwargs.items():
if isinstance(v, (tuple, dict)):
v = hail.str(v)
if not isinstance(v, Expression):
v = literal(v)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's an example of using this new behavior? I feel a bit concerned that aes(foo=(ht.a, ht.b)) is different from:

ht = ht.annotate(c=(ht.a, ht.b))
aes(foo=ht.c)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can remove this, I think I added this in the process of trying to build the manhattan plot in ggplot. The issue I was trying to solve is that passing hl.tuple([something1, something2]) produces a tuple label, but if you pass (something1, something2), you get an exception after it tries to take the literal path. This is a consequence of trying to support both expressions and non-expr sequences. This is sort of a hack, but catches more cases of things that are probably meant to go the expr route.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove it from this PR and revisit separately. I seems to me that if you get an error from hl.literal((ht.a, ht.b)) you should also get an error from my example above. If that's true, we should fix it for both use cases.

hail_field_properties[k] = v
Expand Down
261 changes: 261 additions & 0 deletions hail/python/hail/ggplot/premade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import math
from typing import Tuple, Dict, Union

import plotly.express as px
import plotly.graph_objects as go

import hail
from hail import Table
from hail.expr.expressions import expr_float64, expr_locus, expr_any, expr_numeric, Expression, NumericExpression
from hail.plot.plots import _collect_scatter_plot_data
from hail.typecheck import nullable, typecheck, numeric, dictof, oneof, sized_tupleof


@typecheck(pvals=expr_float64, locus=nullable(expr_locus()), title=nullable(str),
size=int, hover_fields=nullable(dictof(str, expr_any)), collect_all=bool, n_divisions=int,
significance_line=nullable(numeric))
def manhattan_plot(pvals, locus=None, title=None, size=5, hover_fields=None, collect_all=False, n_divisions=500,
significance_line=5e-8):
"""Create a Manhattan plot. (https://en.wikipedia.org/wiki/Manhattan_plot)

Parameters
----------
pvals : :class:`.Float64Expression`
P-values to be plotted.
locus : :class:`.LocusExpression`
Locus values to be plotted.
title : str
Title of the plot.
size : int
Size of markers in screen space units.
hover_fields : Dict[str, :class:`.Expression`]
Dictionary of field names and values to be shown in the HoverTool of the plot.
collect_all : bool
Whether to collect all values or downsample before plotting.
n_divisions : int
Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.
significance_line : float, optional
p-value at which to add a horizontal, dotted red line indicating
genome-wide significance. If ``None``, no line is added.

Returns
-------
:class:`bokeh.plotting.figure.Figure`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be a plotly figure

"""
if locus is None:
locus = pvals._indices.source.locus

ref = locus.dtype.reference_genome

if hover_fields is None:
hover_fields = {}

hover_fields['locus'] = hail.str(locus)
hover_fields['contig_even'] = locus.contig_idx % 2

pvals = -hail.log10(pvals)

source_pd = _collect_scatter_plot_data(
('_global_locus', locus.global_position()),
('_pval', pvals),
fields=hover_fields,
n_divisions=None if collect_all else n_divisions
)
source_pd['p_value'] = [10 ** (-p) for p in source_pd['_pval']]
source_pd['_contig'] = [locus.split(":")[0] for locus in source_pd['locus']]

observed_contigs = set(source_pd['_contig'])
observed_contigs = [contig for contig in ref.contigs.copy() if contig in observed_contigs]

contig_ticks = [ref._contig_global_position(contig) + ref.contig_length(contig) // 2 for contig in observed_contigs]

extra_hover = {k: True for k in hover_fields if k not in ('locus', 'contig_even')}
fig = px.scatter(source_pd, x='_global_locus', y='_pval', color='contig_even',
labels={
'_global_locus': 'global position',
'_pval': '-log10 p-value',
'p_value': 'p-value',
'locus': 'locus'
},
color_continuous_scale=["#00539C", "#EEA47F"],
template='ggplot2',
hover_name="locus",
hover_data={'contig_even': False,
'_global_locus': False,
'_pval': False,
'p_value': ':.2e',
**extra_hover})

fig.update_layout(
xaxis_title="Genomic Coordinate",
yaxis_title="-log10 p-value",
title=title,
showlegend=False,
coloraxis_showscale=False,
xaxis=dict(
tickmode='array',
tickvals=contig_ticks,
ticktext=observed_contigs
)
)

fig.update_traces(marker=dict(size=size),
selector=dict(mode='markers'))

if significance_line is not None:
fig.add_hline(y=-math.log10(significance_line), line_dash='dash', line_color='red', opacity=1, line_width=2)

return fig


@typecheck(pvals=oneof(expr_numeric, sized_tupleof(str, expr_numeric)),
label=nullable(oneof(dictof(str, expr_any), expr_any)), title=nullable(str),
xlabel=nullable(str), ylabel=nullable(str), size=int,
hover_fields=nullable(dictof(str, expr_any)),
width=int, height=int, collect_all=bool, n_divisions=nullable(int), missing_label=str)
def qq_plot(
pvals: Union[NumericExpression, Tuple[str, NumericExpression]],
label: Union[Expression, Dict[str, Expression]] = None,
title: str = 'Q-Q plot',
xlabel: str = 'Expected -log10(p)',
ylabel: str = 'Observed -log10(p)',
size: int = 6,
hover_fields: Dict[str, Expression] = None,
width: int = 800,
height: int = 800,
collect_all: bool = False,
n_divisions: int = 500,
missing_label: str = 'NA'
):
"""Create a Quantile-Quantile plot. (https://en.wikipedia.org/wiki/Q-Q_plot)

``pvals`` must be either:
- a :class:`.NumericExpression`
- a tuple (str, :class:`.NumericExpression`). If passed as a tuple the first element is used as the hover label.

If no label or a single label is provided, then returns :class:`bokeh.plotting.figure.Figure`
Otherwise returns a :class:`bokeh.models.layouts.Column` containing:
- a :class:`bokeh.models.widgets.inputs.Select` dropdown selection widget for labels
- a :class:`bokeh.plotting.figure.Figure` containing the interactive qq plot

Points will be colored by one of the labels defined in the ``label`` using the color scheme defined in
the corresponding entry of ``colors`` if provided (otherwise a default scheme is used). To specify your color
mapper, check `the bokeh documentation <https://bokeh.pydata.org/en/latest/docs/reference/colors.html>`__
for CategoricalMapper for categorical labels, and for LinearColorMapper and LogColorMapper
for continuous labels.
For categorical labels, clicking on one of the items in the legend will hide/show all points with the corresponding label.
Note that using many different labelling schemes in the same plots, particularly if those labels contain many
different classes could slow down the plot interactions.

Hovering on points will display their coordinates, labels and any additional fields specified in ``hover_fields``.

Parameters
----------
pvals : :class:`.NumericExpression` or (str, :class:`.NumericExpression`)
List of x-values to be plotted.
label : :class:`.Expression` or Dict[str, :class:`.Expression`]]
Either a single expression (if a single label is desired), or a
dictionary of label name -> label value for x and y values.
Used to color each point w.r.t its label.
When multiple labels are given, a dropdown will be displayed with the different options.
Can be used with categorical or continuous expressions.
title : str
Title of the scatterplot.
xlabel : str
X-axis label.
ylabel : str
Y-axis label.
size : int
Size of markers in screen space units.
hover_fields : Dict[str, :class:`.Expression`]
Extra fields to be displayed when hovering over a point on the plot.
width: int
Plot width
height: int
Plot height
collect_all : bool
Whether to collect all values or downsample before plotting.
n_divisions : int
Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.
missing_label: str
Label to use when a point is missing data for a categorical label

Returns
-------
:class:`bokeh.plotting.figure.Figure` if no label or a single label was given, otherwise :class:`bokeh.models.layouts.Column`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return type is wrong

"""
hover_fields = {} if hover_fields is None else hover_fields
label = {} if label is None else {'label': label} if isinstance(label, Expression) else label
source = pvals._indices.source
if 'locus' in source.row:
hover_fields['__locus'] = source['locus']

if isinstance(source, Table):
ht = source.select(p_value=pvals, **hover_fields, **label)
else:
ht = source.select_rows(p_value=pvals, **hover_fields, **label).rows()
ht = ht.key_by().select('p_value', *hover_fields, *label).key_by('p_value').persist()
n = ht.count()
ht = ht.annotate(
observed_p=-hail.log10(ht['p_value']),
expected_p=-hail.log10((hail.scan.count() + 1) / n)
).persist()

if 'p_value' not in hover_fields:
hover_fields['p_value'] = ht.p_value

df = _collect_scatter_plot_data(
('expected p-value', ht.expected_p),
('observed p-value', ht.observed_p),
fields={k: ht[k] for k in hover_fields},
n_divisions=None if collect_all else n_divisions,
missing_label=missing_label
)

fig = px.scatter(df, x='expected p-value', y='observed p-value',
template='ggplot2',
hover_name="__locus",
hover_data={**{k: True for k in hover_fields},
'__locus': False,
'p_value': ':.2e'})

fig.update_traces(marker=dict(size=size, color='black'),
selector=dict(mode='markers'))

from hail.methods.statgen import _lambda_gc_agg
lambda_gc, max_p = ht.aggregate(
(_lambda_gc_agg(ht['p_value']), hail.agg.max(hail.max(ht.observed_p, ht.expected_p))))
fig.add_trace(go.Scatter(x=[0, max_p + 1], y=[0, max_p + 1],
mode='lines',
name='expected',
line=dict(color='red', width=3, dash='dot')))

label_color = 'red' if lambda_gc > 1.25 else 'orange' if lambda_gc > 1.1 else 'black'

lgc_label = f'<b>λ GC: {lambda_gc:.2f}</b>'

fig.update_layout(
autosize=False,
width=width,
height=height,
xaxis_title=xlabel,
xaxis_range=[0, max_p + 1],
yaxis_title=ylabel,
yaxis_range=[0, max_p + 1],
title=title,
showlegend=False,
annotations=[
go.layout.Annotation(
font=dict(color=label_color, size=22),
text=lgc_label,
xanchor='left',
yanchor='bottom',
showarrow=False,
x=max_p * 0.8,
y=1
)
],
hovermode="x unified"
)

return fig
15 changes: 1 addition & 14 deletions hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,7 +2812,7 @@ def estimate_size(struct_expression):
if include_row_fields:
characters -= estimate_size(self.row_value)
characters = max(characters, 0)
n_cols = characters // (estimate_size(self.entry) + 4) # 4 for the column index
n_cols = max(characters // (estimate_size(self.entry) + 4), 3) # 4 for the column index
actual_n_cols = self.count_cols()
displayed_n_cols = min(actual_n_cols, n_cols)

Expand Down Expand Up @@ -2894,13 +2894,6 @@ def cols(self) -> Table:
:class:`.Table`
Table with all column fields from the matrix, with one row per column of the matrix.
"""

if len(self.col_key) != 0 and Env.hc()._warn_cols_order:
warning("cols(): Resulting column table is sorted by 'col_key'."
"\n To preserve matrix table column order, "
"first unkey columns with 'key_cols_by()'")
Env.hc()._warn_cols_order = False

return Table(ir.MatrixColsTable(self._mir))

def entries(self) -> Table:
Expand Down Expand Up @@ -2955,12 +2948,6 @@ def entries(self) -> Table:
:class:`.Table`
Table with all non-global fields from the matrix, with **one row per entry of the matrix**.
"""
if Env.hc()._warn_entries_order and len(self.col_key) > 0:
warning("entries(): Resulting entries table is sorted by '(row_key, col_key)'."
"\n To preserve row-major matrix table order, "
"first unkey columns with 'key_cols_by()'")
Env.hc()._warn_entries_order = False

return Table(ir.MatrixEntriesTable(self._mir))

def index_globals(self) -> Expression:
Expand Down
Loading