[qob] allow writing to requester pays buckets (#12855)
CHANGELOG: In Query-on-Batch, allow writing to requester pays buckets,
which was broken before this release.
danking authored Apr 10, 2023
1 parent c65098a commit 9386fdd
Showing 9 changed files with 92 additions and 34 deletions.
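
The user-visible effect, sketched below: with a billing project configured, a Query-on-Batch write to a requester-pays bucket now succeeds. This mirrors the regression test added in this commit (test_requester_pays_write_with_project); the bucket and project names are placeholders, not part of the commit.

import hail as hl

# Assumes a GCS requester-pays bucket you may bill to 'my-billing-project';
# both names below are placeholders.
hl.init(backend='batch', gcs_requester_pays_configuration='my-billing-project')

# Before this commit, this write failed under Query-on-Batch because the
# GCS write channel was opened without the userProject billing option.
ht = hl.utils.range_table(4, n_partitions=4)
ht.write('gs://my-requester-pays-bucket/example.ht', overwrite=True)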
4 changes: 4 additions & 0 deletions hail/python/hail/backend/backend.py
@@ -193,6 +193,10 @@ def register_ir_function(self,
body: Expression):
pass

@abc.abstractmethod
def _is_registered_ir_function_name(self, name: str) -> bool:
pass

@abc.abstractmethod
def persist_expression(self, expr: Expression) -> Expression:
pass
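
The three concrete backends below implement this hook identically; as a minimal sketch of the intended contract (illustrative class, not part of the commit): a backend records each name passed to register_ir_function, answers membership queries through _is_registered_ir_function_name, and forgets everything on stop().

from typing import Set

class SketchBackend:  # illustrative only, not one of the real backends
    def __init__(self):
        self._registered_ir_function_names: Set[str] = set()

    def register_ir_function(self, name, *rest):
        self._registered_ir_function_names.add(name)
        # ... compile and register the IR body with the JVM or batch driver ...

    def _is_registered_ir_function_name(self, name: str) -> bool:
        return name in self._registered_ir_function_names

    def stop(self):
        self._registered_ir_function_names = set()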
8 changes: 7 additions & 1 deletion hail/python/hail/backend/local_backend.py
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple, List
from typing import Optional, Union, Tuple, List, Set
import os
import socket
import socketserver
@@ -155,6 +155,7 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor,
)
self._jhc = hail_package.HailContext.apply(
self._jbackend, branching_factor, optimizer_iterations)
self._registered_ir_function_names: Set[str] = set()

# This has to go after creating the SparkSession. Unclear why.
# Maybe it does its own patch?
@@ -193,6 +194,7 @@ def register_ir_function(self,
r = CSERenderer(stop_at_jir=True)
code = r(finalize_randomness(body._ir))
jbody = (self._parse_value_ir(code, ref_map=dict(zip(value_parameter_names, value_parameter_types)), ir_map=r.jirs))
self._registered_ir_function_names.add(name)

self.hail_package().expr.ir.functions.IRFunctionRegistry.pyRegisterIR(
name,
@@ -202,13 +204,17 @@
return_type._parsable_string(),
jbody)

def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def validate_file_scheme(self, url):
pass

def stop(self):
self._jhc.stop()
self._jhc = None
self._gateway.shutdown()
self._registered_ir_function_names = set()
uninstall_exception_handler()

@property
8 changes: 7 additions & 1 deletion hail/python/hail/backend/service_backend.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple, TypeVar
from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple, TypeVar, Set
import abc
import collections
import struct
@@ -287,6 +287,7 @@ def __init__(self,
self.flags = flags
self.jar_spec = jar_spec
self.functions: List[IRFunction] = []
self._registered_ir_function_names: Set[str] = set()
self.driver_cores = driver_cores
self.driver_memory = driver_memory
self.worker_cores = worker_cores
@@ -328,6 +329,7 @@ def stop(self):
async_to_blocking(self._async_fs.close())
async_to_blocking(self.async_bc.close())
self.functions = []
self._registered_ir_function_names = set()

def render(self, ir):
r = CSERenderer()
@@ -612,6 +614,7 @@ def register_ir_function(self,
value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
return_type: HailType,
body: Expression):
self._registered_ir_function_names.add(name)
self.functions.append(IRFunction(
name,
type_parameters,
@@ -621,6 +624,9 @@
body
))

def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def persist_expression(self, expr):
# FIXME: should use context manager to clean up persisted resources
fname = TemporaryFilename(prefix='persist_expression').name
7 changes: 7 additions & 0 deletions hail/python/hail/backend/spark_backend.py
@@ -1,3 +1,4 @@
from typing import Set
import pkg_resources
import sys
import os
@@ -194,6 +195,7 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master,
self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=self._jvm.JavaSparkContext(self._jsc))
self._jspark_session = self._jbackend.sparkSession()
self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session)
self._registered_ir_function_names: Set[str] = set()

# This has to go after creating the SparkSession. Unclear why.
# Maybe it does its own patch?
@@ -238,6 +240,7 @@ def stop(self):
self._jhc = None
self.sc.stop()
self.sc = None
self._registered_ir_function_names = set()
uninstall_exception_handler()

@property
@@ -267,6 +270,7 @@ def register_ir_function(self, name, type_parameters, argument_names, argument_t
assert not body._ir.uses_randomness
code = r(body._ir)
jbody = (self._parse_value_ir(code, ref_map=dict(zip(argument_names, argument_types)), ir_map=r.jirs))
self._registered_ir_function_names.add(name)

self.hail_package().expr.ir.functions.IRFunctionRegistry.pyRegisterIR(
name,
@@ -275,6 +279,9 @@
return_type._parsable_string(),
jbody)

def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def read_multiple_matrix_tables(self, paths: 'List[str]', intervals: 'List[hl.Interval]', intervals_type):
json_repr = {
'paths': paths,
3 changes: 2 additions & 1 deletion hail/python/hail/experimental/function.py
@@ -1,3 +1,4 @@
from typing import Optional
from hail.expr.expressions import construct_expr, expr_any, unify_all
from hail.expr.types import hail_type
from hail.ir import Apply, Ref
@@ -18,7 +19,7 @@ def __call__(self, *args):


@typecheck(f=anytype, param_types=hail_type, _name=nullable(str), type_args=tupleof(hail_type))
def define_function(f, *param_types, _name=None, type_args=()):
def define_function(f, *param_types, _name: Optional[str] = None, type_args=()) -> Function:
mname = _name if _name is not None else Env.get_uid()
param_names = [Env.get_uid(mname) for _ in param_types]
body = f(*(construct_expr(Ref(pn), pt) for pn, pt in zip(param_names, param_types)))
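
For context, a small usage sketch of define_function (the example is illustrative, not from this commit): the returned Function carries the generated _name that callers can now check against the backend's registry.

import hail as hl

add_one = hl.experimental.define_function(lambda x: x + 1, hl.tint32)
assert hl.eval(add_one(5)) == 6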
21 changes: 11 additions & 10 deletions hail/python/hail/experimental/vcf_combiner/vcf_combiner.py
@@ -7,15 +7,16 @@

import hail as hl
from hail import MatrixTable, Table
from hail.experimental.function import Function
from hail.expr import StructExpression
from hail.expr.expressions import expr_bool, expr_str
from hail.genetics.reference_genome import reference_genome_type
from hail.ir import Apply, TableMapRows, MatrixKeyRowsBy
from hail.typecheck import oneof, sequenceof, typecheck
from hail.utils.java import info, warning, Env

_transform_rows_function_map = {}
_merge_function_map = {}
_transform_rows_function_map: Dict[Tuple[hl.HailType], Function] = {}
_merge_function_map: Dict[Tuple[hl.HailType, hl.HailType], Function] = {}


@typecheck(string=expr_str, has_non_ref=expr_bool)
@@ -121,7 +122,8 @@ def transform_gvcf(mt, info_to_keep=[]) -> Table:
info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
mt = localize(mt)

if mt.row.dtype not in _transform_rows_function_map:
transform_row = _transform_rows_function_map.get(mt.row.dtype)
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
def get_lgt(e, n_alleles, has_non_ref, row):
index = e.GT.unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
@@ -181,7 +183,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

f = hl.experimental.define_function(
transform_row = hl.experimental.define_function(
lambda row: hl.rbind(
hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
lambda alleles_len, has_non_ref: hl.struct(
@@ -191,8 +193,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
__entries=row.__entries.map(
lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))),
mt.row.dtype)
_transform_rows_function_map[mt.row.dtype] = f
transform_row = _transform_rows_function_map[mt.row.dtype]
_transform_rows_function_map[mt.row.dtype] = transform_row
return Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, mt.row._ir)))


@@ -235,8 +236,9 @@ def renumber_entry(entry, old_to_new) -> StructExpression:
# global index of alternate (non-ref) alleles
return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
f = hl.experimental.define_function(
merge_function = _merge_function_map.get((ts.row.dtype, ts.globals.dtype))
if merge_function is None or not hl.current_backend()._is_registered_ir_function_name(merge_function._name):
merge_function = hl.experimental.define_function(
lambda row, gbl:
hl.rbind(
merge_alleles(row.data.map(lambda d: d.alleles)),
@@ -260,8 +262,7 @@ def renumber_entry(entry, old_to_new) -> StructExpression:
hl.dict(hl.range(0, hl.len(alleles.globl)).map(
lambda j: hl.tuple([alleles.globl[j], j])))))),
ts.row.dtype, ts.globals.dtype)
_merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
_merge_function_map[(ts.row.dtype, ts.globals.dtype)] = merge_function
ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
merge_function._ret_type,
ts.row._ir,
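
The change above, repeated in hail/python/hail/vds/combiner/combine.py just below, swaps a plain membership test for a get-then-validate lookup: a cached Function is reused only while its name is still registered with the current backend, so these module-level caches survive an hl.stop()/hl.init() cycle instead of handing back stale registrations. Distilled into a hypothetical helper:

def _get_or_define(cache, key, define):
    # `define` would wrap a call to hl.experimental.define_function(...)
    fn = cache.get(key)
    if fn is None or not hl.current_backend()._is_registered_ir_function_name(fn._name):
        fn = define()
        cache[key] = fn
    return fn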
35 changes: 17 additions & 18 deletions hail/python/hail/vds/combiner/combine.py
@@ -1,13 +1,15 @@
from typing import Collection, List, Optional, Set
from typing import Collection, List, Optional, Set, Tuple, Dict

import hail as hl
from hail import MatrixTable, Table
from hail.ir import Apply, TableMapRows
from hail.experimental.function import Function
from hail.experimental.vcf_combiner.vcf_combiner import combine_gvcfs, localize, parse_as_fields, unlocalize
from ..variant_dataset import VariantDataset

_transform_variant_function_map = {}
_transform_reference_fuction_map = {}
_transform_variant_function_map: Dict[Tuple[hl.HailType, Tuple[str, ...]], Function] = {}
_transform_reference_fuction_map: Dict[Tuple[hl.HailType, Tuple[str, ...]], Function] = {}
_merge_function_map: Dict[Tuple[hl.HailType, hl.HailType], Function] = {}


def make_variants_matrix_table(mt: MatrixTable,
@@ -21,7 +23,8 @@ def make_variants_matrix_table(mt: MatrixTable,
mt = localize(mt)
mt = mt.filter(hl.is_missing(mt.info.END))

if (mt.row.dtype, info_key) not in _transform_variant_function_map:
transform_row = _transform_variant_function_map.get((mt.row.dtype, info_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
def get_lgt(e, n_alleles, has_non_ref, row):
index = e.GT.unphased_diploid_gt_index()
n_no_nonref = n_alleles - hl.int(has_non_ref)
@@ -77,7 +80,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
return hl.struct(**handled_fields, **pass_through_fields)

f = hl.experimental.define_function(
transform_row = hl.experimental.define_function(
lambda row: hl.rbind(
hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
lambda alleles_len, has_non_ref: hl.struct(
@@ -87,8 +90,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
__entries=row.__entries.map(
lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))),
mt.row.dtype)
_transform_variant_function_map[mt.row.dtype, info_key] = f
transform_row = _transform_variant_function_map[mt.row.dtype, info_key]
_transform_variant_function_map[mt.row.dtype, info_key] = transform_row
return unlocalize(Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, mt.row._ir))))


@@ -126,16 +128,16 @@ def make_entry_struct(e, row):
.or_error('found END with non reference-genotype at' + hl.str(row.locus)))

mt = localize(mt)
if (mt.row.dtype, entry_key) not in _transform_reference_fuction_map:
f = hl.experimental.define_function(
transform_row = _transform_reference_fuction_map.get((mt.row.dtype, entry_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
transform_row = hl.experimental.define_function(
lambda row: hl.struct(
locus=row.locus,
__entries=row.__entries.map(
lambda e: make_entry_struct(e, row))),
mt.row.dtype)
_transform_reference_fuction_map[mt.row.dtype, entry_key] = f
_transform_reference_fuction_map[mt.row.dtype, entry_key] = transform_row

transform_row = _transform_reference_fuction_map[mt.row.dtype, entry_key]
return unlocalize(Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, mt.row._ir))))


@@ -186,12 +188,10 @@ def transform_gvcf(mt: MatrixTable,
return VariantDataset(ref_mt, var_mt._key_rows_by_assert_sorted('locus', 'alleles'))


_merge_function_map = {}


def combine_r(ts, ref_block_max_len_field):
if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
f = hl.experimental.define_function(
merge_function = _merge_function_map.get((ts.row.dtype, ts.globals.dtype))
if merge_function is None or not hl.current_backend()._is_registered_ir_function_name(merge_function._name):
merge_function = hl.experimental.define_function(
lambda row, gbl:
hl.struct(
locus=row.locus,
@@ -202,8 +202,7 @@ def combine_r(ts, ref_block_max_len_field):
.map(lambda _: hl.missing(row.data[i].__entries.dtype.element_type)),
row.data[i].__entries))),
ts.row.dtype, ts.globals.dtype)
_merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
_merge_function_map[(ts.row.dtype, ts.globals.dtype)] = merge_function
ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
merge_function._ret_type,
ts.row._ir,
15 changes: 15 additions & 0 deletions hail/python/test/hail/fs/test_worker_driver_fs.py
@@ -2,6 +2,8 @@
from hailtop.utils import secret_alnum_string
from hailtop.test_utils import skip_in_azure

from ..helpers import fails_local_backend, fails_service_backend


@skip_in_azure
def test_requester_pays_no_settings():
@@ -25,6 +27,19 @@ def test_requester_pays_write_no_settings():
assert False


@skip_in_azure
@fails_local_backend()
@fails_service_backend()
def test_requester_pays_write_with_project():
hl.stop()
hl.init(gcs_requester_pays_configuration='hail-vdc')
random_filename = 'gs://hail-services-requester-pays/test_requester_pays_on_worker_driver_' + secret_alnum_string(10)
try:
hl.utils.range_table(4, n_partitions=4).write(random_filename, overwrite=True)
finally:
hl.current_backend().fs.rmtree(random_filename)


@skip_in_azure
def test_requester_pays_with_project():
hl.stop()
25 changes: 22 additions & 3 deletions hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala
@@ -268,13 +268,30 @@ class GoogleStorageFS(
.build()

val os: PositionedOutputStream = new FSPositionedOutputStream(8 * 1024 * 1024) {
private[this] val write: WriteChannel = storage.writer(blobInfo)
private[this] var writer: WriteChannel = null

private[this] def doHandlingRequesterPays(f: => Unit): Unit = {
if (writer != null) {
f
} else {
handleRequesterPays(
{ (options: Seq[BlobWriteOption]) =>
writer = storage.writer(blobInfo, options:_*)
f
},
BlobWriteOption.userProject _,
bucket
)
}
}

override def flush(): Unit = {
bb.flip()

while (bb.remaining() > 0)
write.write(bb)
doHandlingRequesterPays {
writer.write(bb)
}

bb.clear()
}
@@ -284,7 +301,9 @@
if (!closed) {
flush()
retryTransientErrors {
write.close()
doHandlingRequesterPays {
writer.close()
}
}
closed = true
}
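
This Scala change is the heart of the fix: the WriteChannel is no longer opened eagerly when the stream is constructed but lazily, inside handleRequesterPays, so the first flush can retry channel creation with BlobWriteOption.userProject when the bucket demands a billing project. A loose Python paraphrase of the control flow (open_writer and is_requester_pays_error are hypothetical stand-ins for the real client calls, not Hail or GCS APIs):

class LazyRequesterPaysWriter:  # illustrative paraphrase only
    def __init__(self, open_writer, billing_project):
        self._open_writer = open_writer          # opens a GCS write channel
        self._billing_project = billing_project  # project billed on retry
        self._writer = None                      # created on first use

    def _do_handling_requester_pays(self, f):
        if self._writer is not None:
            return f(self._writer)               # channel already open
        try:
            self._writer = self._open_writer()   # first try: no billing option
            return f(self._writer)
        except Exception as e:
            if not is_requester_pays_error(e):   # hypothetical predicate
                raise
            self._writer = self._open_writer(user_project=self._billing_project)
            return f(self._writer)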
