diff --git a/.gitignore b/.gitignore index 755cae1ddcb..7382426347d 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,5 @@ python/hail/docs/_build/* src/main/c/libsimdpp-2.0-rc2.tar.gz src/main/c/libsimdpp-2.0-rc2 target -src/main/c/lib src/main/resources/build-info.properties src/test/resources/example.v11.bgen.idx diff --git a/python/hail/__init__.py b/python/hail/__init__.py index cd53b726278..14cd5bee69f 100644 --- a/python/hail/__init__.py +++ b/python/hail/__init__.py @@ -55,3 +55,5 @@ import builtins __all__.extend([x for x in expr.__all__ if not hasattr(builtins, x)]) del builtins + +__version__ = None # set in hail.init() diff --git a/python/hail/context.py b/python/hail/context.py index d93c97dae66..4cedc320495 100644 --- a/python/hail/context.py +++ b/python/hail/context.py @@ -1,6 +1,7 @@ from pyspark import SparkContext from pyspark.sql import SQLContext +import hail from hail.genetics.reference_genome import ReferenceGenome from hail.typecheck import nullable, typecheck, typecheck_method, enumeration from hail.utils import wrap_to_list, get_env_or_default @@ -65,6 +66,9 @@ def __init__(self, sc=None, app_name="Hail", master=None, local='local[*]', self._default_ref = None Env.hail().variant.ReferenceGenome.setDefaultReference(self._jhc, default_reference) + version = self._jhc.version() + hail.__version__ = version + if not quiet: sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version)) if self._jsc.uiWebUrl().isDefined(): @@ -79,20 +83,15 @@ def __init__(self, sc=None, app_name="Hail", master=None, local='local[*]', ' __ __ <>__\n' ' / /_/ /__ __/ /\n' ' / __ / _ `/ / /\n' - ' /_/ /_/\_,_/_/_/ version {}\n'.format(self.version)) + ' /_/ /_/\_,_/_/_/ version {}\n'.format(version)) - if self.version.startswith('devel'): + if version.startswith('devel'): sys.stderr.write('NOTE: This is a beta version. Interfaces may change\n' ' during the beta period. We also recommend pulling\n' - ' the latest changes weekly.') + ' the latest changes weekly.\n') install_exception_handler() - - @property - def version(self): - return self._jhc.version() - @property def default_reference(self): if not self._default_ref: diff --git a/python/hail/docs/methods/impex.rst b/python/hail/docs/methods/impex.rst index 7ca1df6f95d..625ce8f4dea 100644 --- a/python/hail/docs/methods/impex.rst +++ b/python/hail/docs/methods/impex.rst @@ -16,6 +16,7 @@ Import / Export get_vcf_metadata import_bed import_bgen + index_bgen import_fam import_gen import_locus_intervals @@ -34,6 +35,7 @@ Import / Export .. autofunction:: get_vcf_metadata .. autofunction:: import_bed .. autofunction:: import_bgen +.. autofunction:: index_bgen .. autofunction:: import_fam .. autofunction:: import_gen .. 
autofunction:: import_locus_intervals diff --git a/python/hail/docs/methods/index.rst b/python/hail/docs/methods/index.rst index 57a152b6f98..950f0da1c79 100644 --- a/python/hail/docs/methods/index.rst +++ b/python/hail/docs/methods/index.rst @@ -23,6 +23,7 @@ Methods get_vcf_metadata import_bed import_bgen + index_bgen import_fam import_gen import_locus_intervals diff --git a/python/hail/expr/aggregators.py b/python/hail/expr/aggregators.py index 621d1774df6..d7e5655c45e 100644 --- a/python/hail/expr/aggregators.py +++ b/python/hail/expr/aggregators.py @@ -1,18 +1,43 @@ import hail as hl -from hail.typecheck import * +from hail.typecheck import TypeChecker, TypecheckFailure from hail.expr.expressions import * from hail.expr.expr_ast import * from hail.expr.types import * +class AggregableChecker(TypeChecker): + def __init__(self, coercer): + self.coercer = coercer + super(AggregableChecker, self).__init__() + + def expects(self): + return self.coercer.expects() + + def format(self, arg): + if isinstance(arg, Aggregable): + return f'<aggregable Expression of type {repr(arg.dtype)}>' + else: + return self.coercer.format(arg) + + def check(self, x, caller, param): + coercer = self.coercer + if isinstance(x, Aggregable): + if coercer.can_coerce(x.dtype): + if coercer.requires_conversion(x.dtype): + return x._map(lambda x_: coercer.coerce(x_)) + else: + return x + else: + raise TypecheckFailure + else: + x = coercer.check(x, caller, param) + return _to_agg(x) def _to_agg(x): - if isinstance(x, Aggregable): - return x - else: - x = to_expr(x) - uid = Env.get_uid() - ast = LambdaClassMethod('map', uid, AggregableReference(), x._ast) - return Aggregable(ast, x._type, x._indices, x._aggregations, x._joins) + uid = Env.get_uid() + ast = LambdaClassMethod('map', uid, AggregableReference(), x._ast) + return Aggregable(ast, x._type, x._indices, x._aggregations, x._joins) + +agg_expr = AggregableChecker def _agg_func(name, aggregable, ret_type, *args): @@ -39,7 +64,7 @@ def _check_agg_bindings(expr): if free_variables: raise ExpressionException("dynamic variables created by 'hl.bind' or lambda methods like 'hl.map' may not be aggregated") -@typecheck(expr=oneof(Aggregable, expr_any)) +@typecheck(expr=agg_expr(expr_any)) def collect(expr): """Collect records into an array. @@ -73,10 +98,9 @@ def collect(expr): :class:`.ArrayExpression` Array of all `expr` records. """ - agg = _to_agg(expr) - return _agg_func('collect', agg, tarray(agg._type)) + return _agg_func('collect', expr, tarray(expr.dtype)) -@typecheck(expr=oneof(Aggregable, expr_any)) +@typecheck(expr=agg_expr(expr_any)) def collect_as_set(expr): """Collect records into a set. @@ -103,11 +127,9 @@ def collect_as_set(expr): :class:`.SetExpression` Set of unique `expr` records. """ + return _agg_func('collectAsSet', expr, tarray(expr.dtype)) - agg = _to_agg(expr) - return _agg_func('collectAsSet', agg, tarray(agg._type)) - -@typecheck(expr=nullable(oneof(Aggregable, expr_any))) +@typecheck(expr=nullable(agg_expr(expr_any))) def count(expr=None): """Count the number of records. @@ -147,11 +169,11 @@ def count(expr=None): Total number of records. """ if expr is not None: - return _agg_func('count', _to_agg(expr), tint64) + return _agg_func('count', expr, tint64) else: - return _agg_func('count', _to_agg(0), tint64) + return _agg_func('count', _to_agg(hl.int32(0)), tint64) -@typecheck(condition=oneof(Aggregable, expr_bool)) +@typecheck(condition=expr_bool) def count_where(condition): """Count the number of records where a predicate is ``True``. 
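The `agg_expr` checker introduced above replaces the old `oneof(Aggregable, expr_any)` pattern: a plain expression is wrapped via `_to_agg`, while an existing `Aggregable` is element-coerced only when its type requires conversion. A minimal usage sketch, assuming a running Hail 0.2 (devel) context; the table and field below are hypothetical and not part of this patch:

    import hail as hl
    from hail.expr import aggregators as agg

    t = hl.utils.range_table(10)  # one int32 row field, 'idx'

    # A plain expression is accepted; agg_expr(expr_any) wraps it with _to_agg.
    t.aggregate(agg.sum(t.idx))   # 45

    # agg_expr(expr_float64) coerces the int32 input up front, which is why
    # the explicit is_numeric/_promote_numeric checks disappear in the hunks below.
    t.aggregate(agg.mean(t.idx))  # 4.5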
@@ -178,7 +200,7 @@ def count_where(condition): return _agg_func('count', filter(condition, 0), tint64) -@typecheck(condition=oneof(Aggregable, expr_bool)) +@typecheck(condition=agg_expr(expr_bool)) def any(condition): """Returns ``True`` if `condition` is ``True`` for any record. @@ -216,7 +238,7 @@ def any(condition): """ return count(filter(lambda x: x, condition)) > 0 -@typecheck(condition=oneof(Aggregable, expr_bool)) +@typecheck(condition=agg_expr(expr_bool)) def all(condition): """Returns ``True`` if `condition` is ``True`` for every record. @@ -256,7 +278,7 @@ def all(condition): n_true = count(filter(lambda x: hl.is_defined(x) & x, condition)) return n_defined == n_true -@typecheck(expr=oneof(Aggregable, expr_any)) +@typecheck(expr=agg_expr(expr_any)) def counter(expr): """Count the occurrences of each unique record and return a dictionary. @@ -293,10 +315,11 @@ def counter(expr): :class:`.DictExpression` Dictionary with the number of occurrences of each unique record. """ - agg = _to_agg(expr) - return _agg_func('counter', agg, tdict(agg._type, tint64)) + return _agg_func('counter', expr, tdict(expr.dtype, tint64)) -@typecheck(expr=oneof(Aggregable, expr_any), n=int, ordering=nullable(oneof(expr_any, func_spec(1, expr_any)))) +@typecheck(expr=agg_expr(expr_any), + n=int, + ordering=nullable(oneof(expr_any, func_spec(1, expr_any)))) def take(expr, n, ordering=None): """Take `n` records of `expr`, optionally ordered by `ordering`. @@ -353,37 +376,36 @@ def take(expr, n, ordering=None): Array of up to `n` records of `expr`. """ - agg = _to_agg(expr) n = to_expr(n) if ordering is None: - return _agg_func('take', agg, tarray(agg.dtype), n) + return _agg_func('take', expr, tarray(expr.dtype), n) else: uid = Env.get_uid() if callable(ordering): lambda_result = to_expr( - ordering(construct_expr(VariableReference(uid), agg._type, agg._indices, - agg._aggregations, agg._joins))) + ordering(construct_expr(VariableReference(uid), expr.dtype, expr._indices, + expr._aggregations, expr._joins))) else: lambda_result = ordering - indices, aggregations, joins = unify_all(agg, lambda_result) + indices, aggregations, joins = unify_all(expr, lambda_result) if not (is_numeric(ordering.dtype) or ordering.dtype == tstr): raise TypeError("'take' expects 'ordering' to be or return an ordered expression\n" " Ordered expressions are 'int32', 'int64', 'float32', 'float64', 'str'\n" " Found '{}'".format(ordering._type)) - ast = LambdaClassMethod('takeBy', uid, agg._ast, lambda_result._ast, n._ast) + ast = LambdaClassMethod('takeBy', uid, expr._ast, lambda_result._ast, n._ast) if aggregations: raise ExpressionException('Cannot aggregate an already-aggregated expression') - _check_agg_bindings(agg) + _check_agg_bindings(expr) _check_agg_bindings(lambda_result) - return construct_expr(ast, tarray(agg._type), Indices(source=indices.source), - aggregations.push(Aggregation(agg, lambda_result)), joins) + return construct_expr(ast, tarray(expr._type), Indices(source=indices.source), + aggregations.push(Aggregation(expr, lambda_result)), joins) -@typecheck(expr=oneof(Aggregable, expr_numeric)) +@typecheck(expr=agg_expr(expr_numeric)) def min(expr): """Compute the minimum `expr`. @@ -411,12 +433,9 @@ def min(expr): :class:`.NumericExpression` Minimum value of all `expr` records, same type as `expr`. 
""" - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'min' expects a numeric argument, found '{}'".format(agg._type)) - return _agg_func('min', agg, agg._type) + return _agg_func('min', expr, expr.dtype) -@typecheck(expr=oneof(Aggregable, expr_numeric)) +@typecheck(expr=agg_expr(expr_numeric)) def max(expr): """Compute the maximum `expr`. @@ -444,12 +463,9 @@ def max(expr): :class:`.NumericExpression` Maximum value of all `expr` records, same type as `expr`. """ - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'max' expects a numeric argument, found '{}'".format(agg._type)) - return _agg_func('max', agg, agg._type) + return _agg_func('max', expr, expr.dtype) -@typecheck(expr=oneof(Aggregable, expr_numeric)) +@typecheck(expr=agg_expr(expr_numeric)) def sum(expr): """Compute the sum of all records of `expr`. @@ -481,12 +497,9 @@ def sum(expr): :class:`.Expression` of type :py:data:`.tint64` or :py:data:`.tfloat64` Sum of records of `expr`. """ - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'sum' expects a numeric argument, found '{}'".format(agg._type)) - return _agg_func('sum', agg, agg._type) + return _agg_func('sum', expr, expr.dtype) -@typecheck(expr=oneof(Aggregable, expr_any)) +@typecheck(expr=agg_expr(expr_array(expr_numeric))) def array_sum(expr): """Compute the coordinate-wise sum of all records of `expr`. @@ -512,12 +525,9 @@ def array_sum(expr): ------- :class:`.ArrayNumericExpression` """ - agg = _to_agg(expr) - if not (isinstance(agg._type, tarray) and is_numeric(agg._type.element_type)): - raise TypeError("'array_sum' expects a numeric array argument, found '{}'".format(agg._type)) - return _agg_func('sum', agg, agg._type) + return _agg_func('sum', expr, expr.dtype.element_type) -@typecheck(expr=oneof(Aggregable, expr_numeric)) +@typecheck(expr=agg_expr(expr_float64)) def mean(expr): """Compute the mean value of records of `expr`. @@ -544,12 +554,9 @@ def mean(expr): :class:`.Expression` of type :py:data:`.tfloat64` Mean value of records of `expr`. """ - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'mean' expects a numeric argument, found '{}'".format(agg._type)) - return stats(agg).mean + return stats(expr).mean -@typecheck(expr=oneof(Aggregable, expr_numeric)) +@typecheck(expr=agg_expr(expr_float64)) def stats(expr): """Compute a number of useful statistics about `expr`. @@ -584,18 +591,14 @@ def stats(expr): Struct expression with fields `mean`, `stdev`, `min`, `max`, `n`, and `sum`. """ - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'stats' expects a numeric argument, found '{}'".format(agg._type)) - agg = Expression._promote_numeric(agg, tfloat64) - return _agg_func('stats', agg, tstruct(mean=tfloat64, + return _agg_func('stats', expr, tstruct(mean=tfloat64, stdev=tfloat64, min=tfloat64, max=tfloat64, n=tint64, sum=tfloat64)) -@typecheck(expr=oneof(Aggregable, expr_numeric)) +@typecheck(expr=agg_expr(expr_oneof(expr_int64, expr_float64))) def product(expr): """Compute the product of all records of `expr`. @@ -628,13 +631,9 @@ def product(expr): Product of records of `expr`. 
""" - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'product' expects a numeric argument, found '{}'".format(agg._type)) - agg = Expression._promote_numeric(agg, tfloat64) - return _agg_func('product', agg, agg._type) + return _agg_func('product', expr, expr.dtype) -@typecheck(predicate=oneof(Aggregable, expr_bool)) +@typecheck(predicate=agg_expr(expr_bool)) def fraction(predicate): """Compute the fraction of records where `predicate` is ``True``. @@ -661,20 +660,15 @@ def fraction(predicate): :class:`.Expression` of type :py:data:`.tfloat64` Fraction of records where `predicate` is ``True``. """ - agg = _to_agg(predicate) - if not agg.dtype == tbool: - raise TypeError( - "'fraction' aggregator expects an expression of type 'bool', found '{}'".format(agg.dtype)) - - if agg._aggregations: + if predicate._aggregations: raise ExpressionException('Cannot aggregate an already-aggregated expression') uid = Env.get_uid() - ast = LambdaClassMethod('fraction', uid, agg._ast, VariableReference(uid)) - return construct_expr(ast, tfloat64, Indices(source=agg._indices.source), - agg._aggregations.push(Aggregation(agg)), agg._joins) + ast = LambdaClassMethod('fraction', uid, predicate._ast, VariableReference(uid)) + return construct_expr(ast, tfloat64, Indices(source=predicate._indices.source), + predicate._aggregations.push(Aggregation(predicate)), predicate._joins) -@typecheck(expr=oneof(Aggregable, expr_any)) +@typecheck(expr=agg_expr(expr_call)) def hardy_weinberg(expr): """Compute Hardy-Weinberg Equilbrium (HWE) p-value and heterozygosity ratio. @@ -722,16 +716,11 @@ def hardy_weinberg(expr): :class:`.StructExpression` Struct expression with fields `r_expected_het_freq` and `p_hwe`. """ - t = tstruct(r_expected_het_freq=tfloat64, - p_hwe=tfloat64) - agg = _to_agg(expr) - if not agg.dtype == tcall: - raise TypeError("aggregator 'hardy_weinberg' requires an expression of type 'Call', found '{}'".format( - agg._type.__class__)) - return _agg_func('hardyWeinberg', agg, t) + t = tstruct(r_expected_het_freq=tfloat64, p_hwe=tfloat64) + return _agg_func('hardyWeinberg', expr, t) -@typecheck(expr=oneof(Aggregable, expr_array(), expr_set())) +@typecheck(expr=agg_expr(expr_oneof(expr_array(), expr_set()))) def explode(expr): """Explode an array or set expression to aggregate the elements of all records. @@ -778,14 +767,11 @@ def explode(expr): :class:`.Aggregable` Aggregable expression. """ - agg = _to_agg(expr) - if not (isinstance(agg._type, tset) or isinstance(agg._type, tarray)): - raise TypeError("'explode' expects a 'Set' or 'Array' argument, found '{}'".format(agg._type)) uid = Env.get_uid() - return Aggregable(LambdaClassMethod('flatMap', uid, agg._ast, VariableReference(uid)), - agg._type.element_type, agg._indices, agg._aggregations, agg._joins) + return Aggregable(LambdaClassMethod('flatMap', uid, expr._ast, VariableReference(uid)), + expr._type.element_type, expr._indices, expr._aggregations, expr._joins) -@typecheck(condition=oneof(func_spec(1, expr_bool), expr_bool), expr=oneof(Aggregable, expr_any)) +@typecheck(condition=oneof(func_spec(1, expr_bool), expr_bool), expr=agg_expr(expr_any)) def filter(condition, expr): """Filter records according to a predicate. @@ -820,27 +806,24 @@ def filter(condition, expr): Aggregable expression. 
""" - agg = _to_agg(expr) uid = Env.get_uid() if callable(condition): lambda_result = to_expr( condition( - construct_expr(VariableReference(uid), agg._type, agg._indices, - agg._aggregations, agg._joins))) + construct_expr(VariableReference(uid), expr._type, expr._indices, + expr._aggregations, expr._joins))) else: lambda_result = to_expr(condition) - if lambda_result.dtype != tbool: - raise TypeError( - "'filter' expects the 'condition' argument to be or return an expression of type 'bool', found '{}'".format( - lambda_result.dtype)) - indices, aggregations, joins = unify_all(agg, lambda_result) - ast = LambdaClassMethod('filter', uid, agg._ast, lambda_result._ast) - return Aggregable(ast, agg.dtype, indices, aggregations, joins) + assert lambda_result.dtype == tbool + indices, aggregations, joins = unify_all(expr, lambda_result) + ast = LambdaClassMethod('filter', uid, expr._ast, lambda_result._ast) + return Aggregable(ast, expr.dtype, indices, aggregations, joins) -@typecheck(expr=oneof(Aggregable, expr_call), prior=expr_numeric) + +@typecheck(expr=agg_expr(expr_call), prior=expr_float64) def inbreeding(expr, prior): """Compute inbreeding statistics on calls. @@ -900,35 +883,24 @@ def inbreeding(expr, prior): :class:`.StructExpression` Struct expression with fields `f_stat`, `n_called`, `expected_homs`, `observed_homs`. """ - agg = _to_agg(expr) - - if not agg.dtype == tcall: - raise TypeError("aggregator 'inbreeding' requires an expression of type 'call', found '{}'".format( - agg.dtype)) - - if prior.dtype == tfloat32: - prior = hl.float64(prior) - if not prior.dtype == tfloat64: - raise TypeError("'inbreeding' expects 'prior' to be type 'float32' or 'float64', found '{}'".format(prior._type)) - uid = Env.get_uid() - ast = LambdaClassMethod('inbreeding', uid, agg._ast, prior._ast) + ast = LambdaClassMethod('inbreeding', uid, expr._ast, prior._ast) - indices, aggregations, joins = unify_all(agg, prior) + indices, aggregations, joins = unify_all(expr, prior) if aggregations: raise ExpressionException('Cannot aggregate an already-aggregated expression') - _check_agg_bindings(agg) + _check_agg_bindings(expr) _check_agg_bindings(prior) t = tstruct(f_stat=tfloat64, n_called=tint64, expected_homs=tfloat64, observed_homs=tint64) return construct_expr(ast, t, Indices(source=indices.source), - aggregations.push(Aggregation(agg, prior)), joins) + aggregations.push(Aggregation(expr, prior)), joins) -@typecheck(expr=oneof(Aggregable, expr_call), alleles=expr_array(expr_str)) +@typecheck(expr=agg_expr(expr_call), alleles=expr_array(expr_str)) def call_stats(expr, alleles): """Compute useful call statistics. 
@@ -983,31 +955,25 @@ def call_stats(expr, alleles): :class:`.StructExpression` Struct expression with fields `AC`, `AF`, and `AN` """ - agg = _to_agg(expr) alleles = to_expr(alleles) - uid = Env.get_uid() - if agg.dtype != tcall: - raise TypeError("aggregator 'call_stats' requires an expression of type 'TCall', found '{}'".format( - agg.dtype)) - - ast = LambdaClassMethod('callStats', uid, agg._ast, alleles._ast) - indices, aggregations, joins = unify_all(agg, alleles) + ast = LambdaClassMethod('callStats', uid, expr._ast, alleles._ast) + indices, aggregations, joins = unify_all(expr, alleles) if aggregations: raise ExpressionException('Cannot aggregate an already-aggregated expression') - _check_agg_bindings(agg) + _check_agg_bindings(expr) _check_agg_bindings(alleles) t = tstruct(AC=tarray(tint32), AF=tarray(tfloat64), AN=tint32) return construct_expr(ast, t, Indices(source=indices.source), - aggregations.push(Aggregation(agg, alleles)), joins) + aggregations.push(Aggregation(expr, alleles)), joins) -@typecheck(expr=oneof(Aggregable, expr_numeric), start=expr_float64, end=expr_float64, bins=expr_int32) +@typecheck(expr=agg_expr(expr_float64), start=expr_float64, end=expr_float64, bins=expr_int32) def hist(expr, start, end, bins): """Compute binned counts of a numeric expression. @@ -1053,12 +1019,8 @@ def hist(expr, start, end, bins): :class:`.StructExpression` Struct expression with fields `bin_edges`, `bin_freq`, `n_smaller`, and `n_larger`. """ - agg = _to_agg(expr) - if not is_numeric(agg._type): - raise TypeError("'hist' expects argument 'expr' to be a numeric type, found '{}'".format(agg._type)) - agg = Expression._promote_numeric(agg, tfloat64) t = tstruct(bin_edges=tarray(tfloat64), bin_freq=tarray(tint64), n_smaller=tint64, n_larger=tint64) - return _agg_func('hist', agg, t, start, end, bins) + return _agg_func('hist', expr, t, start, end, bins) diff --git a/python/hail/expr/expressions/base_expression.py b/python/hail/expr/expressions/base_expression.py index 23096bbbdaf..cf2e2d35e8f 100644 --- a/python/hail/expr/expressions/base_expression.py +++ b/python/hail/expr/expressions/base_expression.py @@ -233,7 +233,7 @@ def unify_all(*exprs) -> Tuple[Indices, LinkedList, LinkedList]: "\n Found fields from {n} objects:{fields}".format( n=len(sources), fields=''.join("\n {}: {}".format(src, fds) for src, fds in sources.items()) - )) + )) from None first, rest = exprs[0], exprs[1:] aggregations = first._aggregations joins = first._joins @@ -365,18 +365,26 @@ def _promote_scalar(self, typ): return hail.float64(self) def _promote_numeric(self, typ): - if isinstance(typ, tarray): - if isinstance(self.dtype, tarray): - return hail.map(lambda x: x._promote_scalar(typ.element_type), self) - else: - return self._promote_scalar(typ.element_type) - elif isinstance(self, expressions.Aggregable): - return self._map(lambda x: x._promote_scalar(typ)) + coercer = expressions.coercer_from_dtype(typ) + if isinstance(typ, tarray) and not isinstance(self.dtype, tarray): + return coercer.ec.coerce(self) else: - return self._promote_scalar(typ) + return coercer.coerce(self) def _bin_op_numeric_unify_types(self, name, other): - t = unify_types(self.dtype._scalar_type(), other.dtype._scalar_type()) + def numeric_proxy(t): + if t == tbool: + return tint32 + else: + return t + + def scalar_type(t): + if isinstance(t, tarray): + return numeric_proxy(t.element_type) + else: + return numeric_proxy(t) + + t = unify_types(scalar_type(self.dtype), scalar_type(other.dtype)) if t is None: raise 
NotImplementedError("'{}' {} '{}'".format( self.dtype, name, other.dtype)) @@ -399,8 +407,7 @@ def _bin_op_numeric(self, name, other, ret_type_f=None): return me._bin_op(name, other, ret_type) def _bin_op_numeric_reverse(self, name, other, ret_type_f=None): - other = to_expr(other) - return other._bin_op_numeric(name, self, ret_type_f) + return to_expr(other)._bin_op_numeric(name, self, ret_type_f) def _unary_op(self, name): return expressions.construct_expr(UnaryOperation(self._ast, name), @@ -465,13 +472,13 @@ def _bin_lambda_method(self, name, f, input_type, ret_type_f, *args): return expressions.construct_expr(ast, ret_type_f(lambda_result._type), indices, aggregations, joins) @property - def dtype(self): + def dtype(self) -> HailType: """The data type of the expression. Returns ------- :class:`.HailType` - Data type. + """ return self._type diff --git a/python/hail/expr/expressions/expression_typecheck.py b/python/hail/expr/expressions/expression_typecheck.py index 5e8dabbca64..102e6147203 100644 --- a/python/hail/expr/expressions/expression_typecheck.py +++ b/python/hail/expr/expressions/expression_typecheck.py @@ -52,7 +52,8 @@ def _requires_conversion(self, t: HailType) -> bool: def can_coerce(self, t: HailType) -> bool: ... - def coerce(self, x: Expression) -> Expression: + def coerce(self, x) -> Expression: + x = to_expr(x) if not self.can_coerce(x.dtype): raise ExpressionException(f"cannot coerce type '{x.dtype}' to type '{self.str_t}'") if self._requires_conversion(x.dtype): diff --git a/python/hail/expr/expressions/typed_expressions.py b/python/hail/expr/expressions/typed_expressions.py index 607250fb9c6..982d4bdc308 100644 --- a/python/hail/expr/expressions/typed_expressions.py +++ b/python/hail/expr/expressions/typed_expressions.py @@ -1403,141 +1403,6 @@ def __iter__(self): yield self[i] -class BooleanExpression(Expression): - """Expression of type :py:data:`.tbool`. - - >>> t = hl.literal(True) - >>> f = hl.literal(False) - >>> na = hl.null(hl.tbool) - - .. doctest:: - - >>> hl.eval_expr(t) - True - - >>> hl.eval_expr(f) - False - - >>> hl.eval_expr(na) - None - - """ - - def _bin_op_logical(self, name, other): - other = to_expr(other) - return self._bin_op(name, other, tbool) - - @typecheck_method(other=expr_bool) - def __rand__(self, other): - return self.__and__(other) - - @typecheck_method(other=expr_bool) - def __ror__(self, other): - return self.__or__(other) - - @typecheck_method(other=expr_bool) - def __and__(self, other): - """Return ``True`` if the left and right arguments are ``True``. - - Examples - -------- - .. doctest:: - - >>> hl.eval_expr(t & f) - False - - >>> hl.eval_expr(t & na) - None - - >>> hl.eval_expr(f & na) - False - - The ``&`` and ``|`` operators have higher priority than comparison - operators like ``==``, ``<``, or ``>``. Parentheses are often - necessary: - - .. doctest:: - - >>> x = hl.literal(5) - - >>> hl.eval_expr((x < 10) & (x > 2)) - True - - Parameters - ---------- - other : :class:`.BooleanExpression` - Right-side operand. - - Returns - ------- - :class:`.BooleanExpression` - ``True`` if both left and right are ``True``. - """ - return self._bin_op_logical("&&", other) - - @typecheck_method(other=expr_bool) - def __or__(self, other): - """Return ``True`` if at least one of the left and right arguments is ``True``. - - Examples - -------- - .. 
doctest:: - - >>> hl.eval_expr(t | f) - True - - >>> hl.eval_expr(t | na) - True - - >>> hl.eval_expr(f | na) - None - - The ``&`` and ``|`` operators have higher priority than comparison - operators like ``==``, ``<``, or ``>``. Parentheses are often - necessary: - - .. doctest:: - - >>> x = hl.literal(5) - - >>> hl.eval_expr((x < 10) | (x > 20)) - True - - Parameters - ---------- - other : :class:`.BooleanExpression` - Right-side operand. - - Returns - ------- - :class:`.BooleanExpression` - ``True`` if either left or right is ``True``. - """ - return self._bin_op_logical("||", other) - - def __invert__(self): - """Return the boolean negation. - - Examples - -------- - .. doctest:: - - >>> hl.eval_expr(~t) - False - - >>> hl.eval_expr(~f) - True - - >>> hl.eval_expr(~na) - None - - Returns - ------- - :class:`.BooleanExpression` - Boolean negation. - """ - return self._unary_op("!") - class NumericExpression(Expression): """Expression of numeric type. @@ -1657,7 +1522,8 @@ def __neg__(self): :class:`.NumericExpression` Negated number. """ - return self._unary_op("-") + + return expr_numeric.coerce(self)._unary_op("-") def __add__(self, other): """Add two numbers. @@ -1877,6 +1743,142 @@ def __rpow__(self, other): return self._bin_op_numeric_reverse('**', other, lambda _: tfloat64) +class BooleanExpression(NumericExpression): + """Expression of type :py:data:`.tbool`. + + >>> t = hl.literal(True) + >>> f = hl.literal(False) + >>> na = hl.null(hl.tbool) + + .. doctest:: + + >>> hl.eval_expr(t) + True + + >>> hl.eval_expr(f) + False + + >>> hl.eval_expr(na) + None + + """ + + def _bin_op_logical(self, name, other): + other = to_expr(other) + return self._bin_op(name, other, tbool) + + @typecheck_method(other=expr_bool) + def __rand__(self, other): + return self.__and__(other) + + @typecheck_method(other=expr_bool) + def __ror__(self, other): + return self.__or__(other) + + @typecheck_method(other=expr_bool) + def __and__(self, other): + """Return ``True`` if the left and right arguments are ``True``. + + Examples + -------- + .. doctest:: + + >>> hl.eval_expr(t & f) + False + + >>> hl.eval_expr(t & na) + None + + >>> hl.eval_expr(f & na) + False + + The ``&`` and ``|`` operators have higher priority than comparison + operators like ``==``, ``<``, or ``>``. Parentheses are often + necessary: + + .. doctest:: + + >>> x = hl.literal(5) + + >>> hl.eval_expr((x < 10) & (x > 2)) + True + + Parameters + ---------- + other : :class:`.BooleanExpression` + Right-side operand. + + Returns + ------- + :class:`.BooleanExpression` + ``True`` if both left and right are ``True``. + """ + return self._bin_op_logical("&&", other) + + @typecheck_method(other=expr_bool) + def __or__(self, other): + """Return ``True`` if at least one of the left and right arguments is ``True``. + + Examples + -------- + .. doctest:: + + >>> hl.eval_expr(t | f) + True + + >>> hl.eval_expr(t | na) + True + + >>> hl.eval_expr(f | na) + None + + The ``&`` and ``|`` operators have higher priority than comparison + operators like ``==``, ``<``, or ``>``. Parentheses are often + necessary: + + .. doctest:: + + >>> x = hl.literal(5) + + >>> hl.eval_expr((x < 10) | (x > 20)) + True + + Parameters + ---------- + other : :class:`.BooleanExpression` + Right-side operand. + + Returns + ------- + :class:`.BooleanExpression` + ``True`` if either left or right is ``True``. + """ + return self._bin_op_logical("||", other) + + def __invert__(self): + """Return the boolean negation. + + Examples + -------- + .. 
doctest:: + + >>> hl.eval_expr(~t) + False + + >>> hl.eval_expr(~f) + True + + >>> hl.eval_expr(~na) + None + + Returns + ------- + :class:`.BooleanExpression` + Boolean negation. + """ + return self._unary_op("!") + + class Float64Expression(NumericExpression): """Expression of type :py:data:`.tfloat64`.""" pass @@ -2048,6 +2050,78 @@ def split(self, delim, n=None): else: return self._method("split", tarray(tstr), delim, n) + def lower(self): + """Returns a copy of the string, but with upper case letters converted + to lower case. + + Examples + -------- + >>> s.lower().value + 'the quick brown fox' + + Returns + ------- + :class:`.StringExpression` + """ + return self._method("lower", tstr) + + def upper(self): + """Returns a copy of the string, but with lower case letters converted + to upper case. + + Examples + -------- + >>> s.upper().value + 'THE QUICK BROWN FOX' + + Returns + ------- + :class:`.StringExpression` + """ + return self._method("upper", tstr) + + def strip(self): + r"""Returns a copy of the string with whitespace removed from the start + and end. + + Examples + -------- + >>> s2 = hl.str(' once upon a time\n') + >>> s2.strip().value + 'once upon a time' + + Returns + ------- + :class:`.StringExpression` + """ + return self._method("strip", tstr) + + @typecheck_method(substr=expr_str) + def contains(self, substr): + """Returns whether `substr` is contained in the string. + + Examples + -------- + >>> s.contains('fox').value + True + + >>> s.contains('dog').value + False + + Note + ---- + This method is case-sensitive. + + Parameters + ---------- + substr : :class:`.StringExpression` + + Returns + ------- + :class:`.BooleanExpression` + """ + return self._method("contains", tbool, substr) + @typecheck_method(regex=str) def matches(self, regex): """Returns ``True`` if the string contains any match for the given regex. @@ -2091,30 +2165,6 @@ def matches(self, regex): return construct_expr(RegexMatch(self._ast, regex), tbool, self._indices, self._aggregations, self._joins) - def to_boolean(self): - """Parse the string to a Boolean. - - Examples - -------- - .. doctest:: - - >>> s = hl.literal('TRUE') - >>> hl.eval_expr(s.to_boolean()) - True - - Notes - ----- - Acceptable values are: ``True``, ``true``, ``TRUE``, ``False``, - ``false``, and ``FALSE``. - - Returns - ------- - :class:`.BooleanExpression` - Parsed Boolean expression. - """ - - return self._method("toBoolean", tbool) - class CallExpression(Expression): """Expression of type :py:data:`.tcall`. 
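Besides moving `BooleanExpression` under `NumericExpression`, the typed_expressions.py hunk above adds four `StringExpression` methods — `lower`, `upper`, `strip`, and `contains` — backed by the matching `FunctionRegistry` entries later in this diff and exercised by the new assertions in test_expr.py. A doctest-style illustration in the spirit of the new docstrings:

    >>> import hail as hl
    >>> s = hl.literal('The Quick Brown Fox')
    >>> s.lower().value
    'the quick brown fox'
    >>> hl.str('  padded  ').strip().value
    'padded'
    >>> s.contains('Quick').value  # case-sensitive
    True
    >>> s.contains('quick').value
    False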
diff --git a/python/hail/expr/functions.py b/python/hail/expr/functions.py index 598ec2002d1..283df9f4ac0 100644 --- a/python/hail/expr/functions.py +++ b/python/hail/expr/functions.py @@ -2116,7 +2116,10 @@ def str(x: Expression) -> StringExpression: ------- :class:`.StringExpression` """ - return _func("str", tstr, x) + if x.dtype == tstr: + return x + else: + return _func("str", tstr, x) @typecheck(c=expr_call, i=expr_int32) diff --git a/python/hail/expr/types.py b/python/hail/expr/types.py index 501a1072219..749567483e4 100644 --- a/python/hail/expr/types.py +++ b/python/hail/expr/types.py @@ -198,13 +198,6 @@ def _convert_from_json_na(self, x): def _convert_from_json(self, x): return x - def _scalar_type(self): - if isinstance(self, tarray): - return self.element_type - else: - assert is_numeric(self) - return self - hail_type = oneof(HailType, transformed((str, dtype))) @@ -1138,8 +1131,8 @@ def _convert_to_json(self, x): hts_entry_schema = tstruct(GT=tcall, AD=tarray(tint32), DP=tint32, GQ=tint32, PL=tarray(tint32)) -_numeric_types = {tint32, tint64, tfloat32, tfloat64} -_primitive_types = _numeric_types.union({tbool, tstr}) +_numeric_types = {tbool, tint32, tint64, tfloat32, tfloat64} +_primitive_types = _numeric_types.union({tstr}) @typecheck(t=HailType) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index 5b69ea356ff..6027440d35a 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -1,5 +1,6 @@ import itertools from typing import * +from collections import OrderedDict import hail import hail as hl @@ -940,16 +941,12 @@ def annotate_entries(self, **named_exprs: NamedExprs) -> 'MatrixTable': :class:`.MatrixTable` Matrix table with new row-and-column-indexed field(s). """ - exprs = [] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} - base, cleanup = self._process_joins(*named_exprs.values()) for k, v in named_exprs.items(): - analyze('MatrixTable.annotate_entries', v, self._entry_indices) - exprs.append('g.{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql())) check_collisions(self._fields, k, self._entry_indices) - m = MatrixTable(base._jvds.annotateEntriesExpr(",\n".join(exprs))) - return cleanup(m) + + return self._select_entries("MatrixTable.annotate_entries", self.entry.annotate(**named_exprs)) def select_globals(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'MatrixTable': """Select existing global fields or create new fields by name, dropping the rest. 
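In the matrixtable.py hunk above, `annotate_entries` stops assembling per-field `annotateEntriesExpr` strings and instead routes a single struct through the new `_select_entries` helper (added near the end of this file's diff); `select_entries`, `drop`, and `_annotate_all` follow suit in the next hunks. Behavior is meant to be unchanged, as the new `test_select_entries` checks. A small sketch of the equivalent call paths, using a hypothetical range matrix table:

    import hail as hl

    mt = hl.utils.range_matrix_table(10, 10, n_partitions=4)
    mt = mt.annotate_entries(foo=mt.row_idx * 10 + mt.col_idx)
    # internally this is now roughly:
    #   mt._select_entries('MatrixTable.annotate_entries',
    #                      mt.entry.annotate(foo=mt.row_idx * 10 + mt.col_idx))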
@@ -1188,31 +1185,21 @@ def select_entries(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'Ma """ exprs = [to_expr(e) if not isinstance(e, str) else self[e] for e in exprs] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} - strs = [] - all_exprs = [] - base, cleanup = self._process_joins(*itertools.chain(exprs, named_exprs.values())) + assignments = OrderedDict() - ids = [] for e in exprs: - all_exprs.append(e) - analyze('MatrixTable.select_entries', e, self._entry_indices) if not e._indices == self._entry_indices: # detect row or col fields here raise ExpressionException("method 'select_entries' parameter 'exprs' expects entry-indexed fields," " found indices {}".format(list(e._indices.axes))) if e._ast.search(lambda ast: not isinstance(ast, TopLevelReference) and not isinstance(ast, Select)): raise ExpressionException("method 'select_entries' expects keyword arguments for complex expressions") - strs.append(e._ast.to_hql()) - ids.append(e._ast.name) + assignments[e._ast.name] = e for k, e in named_exprs.items(): - all_exprs.append(e) - analyze('MatrixTable.select_entries', e, self._entry_indices) check_collisions(self._fields, k, self._entry_indices) - strs.append('{} = {}'.format(escape_id(k), e._ast.to_hql())) - ids.append(k) - check_field_uniqueness(ids) - m = MatrixTable(base._jvds.selectEntries(strs)) - return cleanup(m) + assignments[k] = e + check_field_uniqueness(assignments.keys()) + return self._select_entries("MatrixTable.select_entries", hl.struct(**assignments)) @typecheck_method(exprs=oneof(str, Expression)) def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable': @@ -1290,10 +1277,9 @@ def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable': new_col_fields = [f for f in m.col if f not in fields_to_drop] m = m.select_cols(*new_col_fields) - entry_fields = [x for x in fields_to_drop if self._fields[x]._indices == self._entry_indices] - if any(self._fields[field]._indices == self._entry_indices for field in fields_to_drop): - # need to drop entry fields - m = MatrixTable(m._jvds.dropEntries(entry_fields)) + entry_fields = [field for field in fields_to_drop if self._fields[field]._indices == self._entry_indices] + if entry_fields: + m = m._select_entries("MatrixTable.drop_entries", m.entry.drop(*entry_fields)) return m @@ -2208,7 +2194,6 @@ def joiner(left: MatrixTable): src_cols_indexed = src_cols_indexed.annotate(**{col_uid: hl.int32(src_cols_indexed[col_uid])}) left = left._annotate_all(row_exprs = {row_uid: localized.index(*row_exprs)[row_uid]}, col_exprs = {col_uid: src_cols_indexed.index(*col_exprs)[col_uid]}) - return left.annotate_entries(**{uid: left[row_uid][left[col_uid]]}) return construct_expr(Select(TopLevelReference('g', self._entry_indices), uid), @@ -2247,12 +2232,9 @@ def _annotate_all(self, check_collisions(self._fields, k, self._col_indices) jmt = jmt.annotateColsExpr(",\n".join(col_strs)) if entry_exprs: - entry_strs = [] - for k, v in entry_exprs.items(): - analyze('MatrixTable.annotate_entries', v, self._entry_indices) - entry_strs.append('g.{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql())) - check_collisions(self._fields, k, self._entry_indices) - jmt = jmt.annotateEntriesExpr(",\n".join(entry_strs)) + entry_struct = self.entry.annotate(**entry_exprs) + analyze("MatrixTable.annotate_entries", entry_struct, self._entry_indices) + jmt = jmt.selectEntries(entry_struct._ast.to_hql()) if global_exprs: global_strs = [] for k, v in global_exprs.items(): @@ -2622,6 +2604,12 @@ def add_col_index(self, name: str = 'col_idx') -> 
'MatrixTable': def _same(self, other, tolerance=1e-6): return self._jvds.same(other._jvds, tolerance) + @typecheck_method(caller=str, s=expr_struct()) + def _select_entries(self, caller, s): + base, cleanup = self._process_joins(s) + analyze(caller, s, self._entry_indices) + return cleanup(MatrixTable(base._jvds.selectEntries(s._ast.to_hql()))) + @typecheck(datasets=matrix_table_type) def union_rows(*datasets: Tuple['MatrixTable']) -> 'MatrixTable': """Take the union of dataset rows. diff --git a/python/hail/methods/statgen.py b/python/hail/methods/statgen.py index 44a9da995a1..03f8d259fa1 100644 --- a/python/hail/methods/statgen.py +++ b/python/hail/methods/statgen.py @@ -1300,7 +1300,7 @@ def hwe_normalized_pca(dataset, k=10, compute_loadings=False, as_array=False): :math:`i` and :math:`j` of :math:`M`; in terms of :math:`C` it is .. math:: - + \frac{1}{m}\sum_{l\in\mathcal{C}_i\cap\mathcal{C}_j}\frac{(C_{il}-2p_l)(C_{jl} - 2p_l)}{2p_l(1-p_l)} where :math:`\mathcal{C}_i = \{l \mid C_{il} \text{ is non-missing}\}`. In @@ -1322,7 +1322,7 @@ def hwe_normalized_pca(dataset, k=10, compute_loadings=False, as_array=False): Parameters ---------- dataset : :class:`.MatrixTable` - Dataset. + Matrix table with entry-indexed ``GT`` field of type :py:data:`.tcall`. k : :obj:`int` Number of principal components. compute_loadings : :obj:`bool` @@ -1378,7 +1378,7 @@ def pca(entry_expr, k=10, compute_loadings=False, as_array=False): 1s encoding missingness of genotype calls. >>> eigenvalues, scores, _ = hl.pca(hl.int(hl.is_defined(dataset.GT)), - ... k=2) + ... k=2) Warning ------- @@ -1437,8 +1437,6 @@ def pca(entry_expr, k=10, compute_loadings=False, as_array=False): Parameters ---------- - dataset : :class:`.MatrixTable` - Dataset. entry_expr : :class:`.Expression` Numeric expression for matrix entries. k : :obj:`int` @@ -1473,10 +1471,9 @@ def pca(entry_expr, k=10, compute_loadings=False, as_array=False): k=int, maf=numeric, block_size=int, - path=nullable(str), min_kinship=numeric, statistics=enumeration("phi", "phik2", "phik2k0", "all")) -def pc_relate(ds, k, maf, path=None, block_size=512, min_kinship=-float("inf"), statistics="all"): +def pc_relate(ds, k, maf, block_size=512, min_kinship=-float("inf"), statistics="all"): """Compute relatedness estimates between individuals using a variant of the PC-Relate method. @@ -1699,9 +1696,6 @@ def pc_relate(ds, k, maf, path=None, block_size=512, min_kinship=-float("inf"), maf : :obj:`float` The minimum individual-specific allele frequency for an allele used to measure relatedness. - path : :obj:`str` or :obj:`None` - A temporary directory to store intermediate matrices. Storing the matrices - to a file system is necessary for reliable execution of this method. 
block_size : :obj:`int` the side length of the blocks of the block-distributed matrices; this should be set such that at least three of these matrices fit in memory @@ -1735,16 +1729,15 @@ def pc_relate(ds, k, maf, path=None, block_size=512, min_kinship=-float("inf"), ds = ds.annotate_cols(scores=scores[ds.s].scores) - return pc_relate_with_scores(ds, ds.scores, maf, path, block_size, min_kinship, statistics) + return pc_relate_with_scores(ds, ds.scores, maf, block_size, min_kinship, statistics) @typecheck(ds=MatrixTable, scores=expr_array(expr_float64), maf=numeric, - path=nullable(str), block_size=int, min_kinship=numeric, statistics=enumeration("phi", "phik2", "phik2k0", "all")) -def pc_relate_with_scores(ds, scores, maf, path=None, block_size=512, min_kinship=-float("inf"), statistics="all"): +def pc_relate_with_scores(ds, scores, maf, block_size=512, min_kinship=-float("inf"), statistics="all"): """The PC-Relate method parameterized by sample PC scores See the detailed documentation at :meth:`.pc_relate`. @@ -1760,9 +1753,6 @@ def pc_relate_with_scores(ds, scores, maf, path=None, block_size=512, min_kinshi maf : :obj:`float` The minimum individual-specific allele frequency for an allele used to measure relatedness. - path : :obj:`str` or :obj:`None` - A temporary directory to store intermediate matrices. Storing the matrices - to a file system is necessary for reliable execution of this method. block_size : :obj:`int` the side length of the blocks of the block-distributed matrices; this should be set such that at least three of these matrices fit in memory @@ -1802,9 +1792,8 @@ def pc_relate_with_scores(ds, scores, maf, path=None, block_size=512, min_kinshi mean_gt=(agg.sum(ds.naa) / agg.count_where(hl.is_defined(ds.GT)))) mean_imputed_gt = hl.or_else(ds.naa, ds.mean_gt) - g = BlockMatrix.write_from_entry_expr(mean_imputed_gt, - path=path, - block_size=block_size) + g = BlockMatrix.from_entry_expr(mean_imputed_gt, + block_size=block_size) pc_scores = (ds.add_col_index('column_index').cols().collect()) pc_scores.sort(key=lambda x: x.column_index) @@ -2345,10 +2334,11 @@ def realized_relationship_matrix(call_expr): fst=nullable(listof(numeric)), af_dist=oneof(UniformDist, BetaDist, TruncatedBetaDist), seed=int, - reference_genome=reference_genome_type) + reference_genome=reference_genome_type, + mixture=bool) def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=None, pop_dist=None, fst=None, af_dist=UniformDist(0.1, 0.9), - seed=0, reference_genome='default'): + seed=0, reference_genome='default', mixture=False): r"""Generate a matrix table of variants, samples, and genotypes using the Balding-Nichols model. @@ -2411,7 +2401,7 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non population allele frequencies by :math:`p_{k, m}`, and diploid, unphased genotype calls by :math:`g_{n, m}` (0, 1, and 2 correspond to homozygous reference, heterozygous, and homozygous variant, respectively). - + The generative model is then given by: .. math:: @@ -2441,6 +2431,7 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non - `ancestral_af_dist` (:class:`.tstruct`) -- Description of the ancestral allele frequency distribution. - `seed` (:py:data:`.tint32`) -- Random seed. + - `mixture` (:py:data:`.tbool`) -- Value of `mixture` parameter. Row fields: @@ -2484,6 +2475,12 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non Random seed. 
reference_genome : :obj:`str` or :class:`.ReferenceGenome` Reference genome to use. + mixture : :obj:`bool` + Treat `pop_dist` as the parameters of a Dirichlet distribution, + as in the Pritchard-Stephens-Donnelly model. This feature is + EXPERIMENTAL and currently undocumented and untested. + If ``True``, the type of `pop` is :class:`.tarray` of + :py:data:`.tfloat64` and the value is the mixture proportions. Returns ------- @@ -2507,7 +2504,8 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non jvm_fst_opt, af_dist._jrep(), seed, - reference_genome._jrep) + reference_genome._jrep, + mixture) return MatrixTable(jmt) diff --git a/python/hail/table.py b/python/hail/table.py index 2a6e98c3011..e8bb32a134e 100644 --- a/python/hail/table.py +++ b/python/hail/table.py @@ -376,6 +376,12 @@ def _select(self, caller, s): analyze(caller, s, self._row_indices) return cleanup(Table(base._jt.select(s._ast.to_hql()))) + @typecheck_method(caller=str, s=expr_struct()) + def _select_globals(self, caller, s): + base, cleanup = self._process_joins(s) + analyze(caller, s, self._global_indices) + return cleanup(Table(base._jt.selectGlobal(s._ast.to_hql()))) + @classmethod @typecheck_method(rows=anytype, schema=tstruct, @@ -463,17 +469,10 @@ def annotate_globals(self, **named_exprs): :class:`.Table` Table with new global field(s). """ - - exprs = [] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} - base, cleanup = self._process_joins(*named_exprs.values()) for k, v in named_exprs.items(): - analyze('Table.annotate_globals', v, self._global_indices) check_collisions(self._fields, k, self._global_indices) - exprs.append('{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql())) - - m = Table(base._jt.annotateGlobalExpr(",\n".join(exprs))) - return cleanup(m) + return self._select_globals('Table.annotate_globals', self.globals.annotate(**named_exprs)) def select_globals(self, *exprs, **named_exprs): """Select existing global fields or create new fields by name, dropping the rest. @@ -511,29 +510,22 @@ def select_globals(self, *exprs, **named_exprs): :class:`.Table` Table with specified global fields. 
""" - exprs = [self[e] if not isinstance(e, Expression) else e for e in exprs] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} - strs = [] - all_exprs = [] - base, cleanup = self._process_joins(*itertools.chain(exprs, named_exprs.values())) + assignments = OrderedDict() - ids = [] for e in exprs: - all_exprs.append(e) - analyze('Table.select_globals', e, self._global_indices) if e._ast.search(lambda ast: not isinstance(ast, TopLevelReference) and not isinstance(ast, Select)): raise ExpressionException("method 'select_globals' expects keyword arguments for complex expressions") - strs.append(e._ast.to_hql()) - ids.append(e._ast.expand()[0].name) + assert isinstance(e._ast, Select) + assignments[e._ast.name] = e + for k, e in named_exprs.items(): - all_exprs.append(e) - analyze('Table.select_globals', e, self._global_indices) check_collisions(self._fields, k, self._global_indices) - strs.append('{} = {}'.format(escape_id(k), to_expr(e)._ast.to_hql())) - ids.append(k) - check_field_uniqueness(ids) - return cleanup(Table(base._jt.selectGlobal(strs))) + assignments[k] = e + + check_field_uniqueness(assignments.keys()) + return self._select_globals('Table.select_globals', hl.struct(**assignments)) def transmute_globals(self, **named_exprs): raise NotImplementedError() diff --git a/python/hail/tests/test_api.py b/python/hail/tests/test_api.py index 8c6e5c4b883..2fc56ee0ea1 100644 --- a/python/hail/tests/test_api.py +++ b/python/hail/tests/test_api.py @@ -616,6 +616,15 @@ def test_query(self): qgs = vds.aggregate_entries(hl.Struct(x=agg.collect(agg.filter(False, vds.y1)), y=agg.collect(agg.filter(hl.rand_bool(0.1), vds.GT)))) + def test_select_entries(self): + mt = hl.utils.range_matrix_table(10, 10, n_partitions=4) + mt = mt.annotate_entries(a=hl.struct(b=mt.row_idx, c=mt.col_idx), foo=mt.row_idx * 10 + mt.col_idx) + mt = mt.select_entries(mt.a.b, mt.a.c, mt.foo) + mt = mt.annotate_entries(bc=mt.b * 10 + mt.c) + mt_entries = mt.entries() + + assert(mt_entries.all(mt_entries.bc == mt_entries.foo)) + def test_drop(self): vds = self.get_vds() vds = vds.annotate_globals(foo=5) @@ -800,12 +809,13 @@ def test_computed_key_join_3(self): def test_entry_join_self(self): mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4) - mt1 = mt1.annotate_entries(x = mt1.row_idx + mt1.col_idx) + mt1 = mt1.annotate_entries(x = 10*mt1.row_idx + mt1.col_idx) self.assertEqual(mt1[mt1.row_idx, mt1.col_idx].dtype, mt1.entry.dtype) mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x) mt_join_entries = mt_join.entries() + self.assertTrue(mt_join_entries.all(mt_join_entries.x == mt_join_entries.x2)) def test_entry_join_const(self): diff --git a/python/hail/tests/test_expr.py b/python/hail/tests/test_expr.py index 2aa48c7e3d2..fe3cef09ee3 100644 --- a/python/hail/tests/test_expr.py +++ b/python/hail/tests/test_expr.py @@ -307,6 +307,21 @@ def test_str_ops(self): self.assertFalse(hl.eval_expr(hl.bool(s5))) self.assertFalse(hl.eval_expr(hl.bool(s6))) + # lower + s = hl.literal('abcABC123') + self.assertEqual(s.lower().value, 'abcabc123') + self.assertEqual(s.upper().value, 'ABCABC123') + + s_whitespace = hl.literal(' \t 1 2 3 \t\n') + self.assertEqual(s_whitespace.strip().value, '1 2 3') + + self.assertEqual(s.contains('ABC').value, True) + self.assertEqual(s.contains('a').value, True) + self.assertEqual(s.contains('C123').value, True) + self.assertEqual(s.contains('').value, True) + self.assertEqual(s.contains('C1234').value, False) + self.assertEqual(s.contains(' ').value, False) + def 
check_expr(self, expr, expected, expected_type): self.assertEqual(expected_type, expr.dtype) self.assertEqual((expected, expected_type), hl.eval_expr_typed(expr)) @@ -849,6 +864,23 @@ def test_modulus(self): self.check_expr(a_float32 % float64_3s, expected, tarray(tfloat64)) self.check_expr(a_float64 % float64_3s, expected, tarray(tfloat64)) + def test_bools_can_math(self): + b1 = hl.literal(True) + b2 = hl.literal(False) + + b_array = hl.literal([True, False]) + f1 = hl.float64(5.5) + f_array = hl.array([1.5, 2.5]) + + self.assertEqual((b1 * b2).value, 0) + self.assertEqual((b1 + b2).value, 1) + self.assertEqual((b1 - b2).value, 1) + self.assertEqual((b1 / b1).value, 1.0) + self.assertEqual((f1 * b2).value, 0.0) + self.assertEqual((b_array + f1).value, [6.5, 5.5]) + self.assertEqual((b_array + f_array).value, [2.5, 2.5]) + + def test_allele_methods(self): self.assertTrue(hl.eval_expr(hl.is_transition("A", "G"))) self.assertFalse(hl.eval_expr(hl.is_transversion("A", "G"))) diff --git a/python/hail/typecheck/__init__.py b/python/hail/typecheck/__init__.py index e8b3f2a80c9..bae65f3d661 100644 --- a/python/hail/typecheck/__init__.py +++ b/python/hail/typecheck/__init__.py @@ -1,6 +1,7 @@ from .check import * -__all__ = ['typecheck', +__all__ = ['TypeChecker', + 'typecheck', 'typecheck_method', 'anytype', 'nullable', diff --git a/python/hail/typecheck/check.py b/python/hail/typecheck/check.py index b640f2952b1..e0151f12ae8 100644 --- a/python/hail/typecheck/check.py +++ b/python/hail/typecheck/check.py @@ -427,7 +427,7 @@ def check_all(f, args, kwargs, checks, is_method): argname=argname, expected=tc.expects(), found=tc.format(arg) - )) + )) from None else: raise TypeError("{fname}: parameter '*{argname}' (arg {idx} of {tot}): " "expected {expected}, found {found}".format( @@ -437,7 +437,7 @@ def check_all(f, args, kwargs, checks, is_method): tot=len(pos_args) - len(named_args), expected=tc.expects(), found=tc.format(arg) - )) + )) from None kwargs_ = {} @@ -453,7 +453,7 @@ def check_all(f, args, kwargs, checks, is_method): argname=kw, expected=tc.expects(), found=tc.format(kwargs[kw]) - )) + )) from None if spec.varkw: tc = checks[spec.varkw] for argname, arg in kwargs.items(): @@ -467,7 +467,7 @@ def check_all(f, args, kwargs, checks, is_method): argname=argname, expected=tc.expects(), found=tc.format(arg) - )) + )) from None return args_, kwargs_ diff --git a/python/hail/utils/java.py b/python/hail/utils/java.py index 63e0eb0e902..e239c042434 100644 --- a/python/hail/utils/java.py +++ b/python/hail/utils/java.py @@ -5,8 +5,7 @@ from threading import Thread import py4j -import numpy as np - +import hail class FatalError(Exception): """:class:`.FatalError` is an error thrown by Hail method failures""" @@ -194,11 +193,11 @@ def deco(*args, **kwargs): deepest, full = tpl._1(), tpl._2() raise FatalError('%s\n\nJava stack trace:\n%s\n' 'Hail version: %s\n' - 'Error summary: %s' % (deepest, full, Env.hc().version, deepest)) from None + 'Error summary: %s' % (deepest, full, hail.__version__, deepest)) from None except pyspark.sql.utils.CapturedException as e: raise FatalError('%s\n\nJava stack trace:\n%s\n' 'Hail version: %s\n' - 'Error summary: %s' % (e.desc, e.stackTrace, Env.hc().version, e.desc)) from None + 'Error summary: %s' % (e.desc, e.stackTrace, hail.__version__, e.desc)) from None return deco diff --git a/src/main/c/lib/darwin/libhail.dylib b/src/main/c/lib/darwin/libhail.dylib new file mode 100755 index 00000000000..367350f4853 Binary files /dev/null and 
b/src/main/c/lib/darwin/libhail.dylib differ diff --git a/src/main/scala/is/hail/HailContext.scala b/src/main/scala/is/hail/HailContext.scala index b25b4a99b90..c11e4b57a72 100644 --- a/src/main/scala/is/hail/HailContext.scala +++ b/src/main/scala/is/hail/HailContext.scala @@ -603,8 +603,9 @@ class HailContext private(val sc: SparkContext, fst: Option[Array[Double]] = None, afDist: Distribution = UniformDist(0.1, 0.9), seed: Int = 0, - rg: ReferenceGenome = ReferenceGenome.defaultReference): MatrixTable = - BaldingNicholsModel(this, populations, samples, variants, popDist, fst, seed, nPartitions, afDist, rg) + rg: ReferenceGenome = ReferenceGenome.defaultReference, + mixture: Boolean = false): MatrixTable = + BaldingNicholsModel(this, populations, samples, variants, popDist, fst, seed, nPartitions, afDist, rg, mixture) def genDataset(): MatrixTable = VSMSubgen.realistic.gen(this).sample() diff --git a/src/main/scala/is/hail/annotations/BroadcastValue.scala b/src/main/scala/is/hail/annotations/BroadcastValue.scala index 68ef333627f..b9f57605236 100644 --- a/src/main/scala/is/hail/annotations/BroadcastValue.scala +++ b/src/main/scala/is/hail/annotations/BroadcastValue.scala @@ -6,4 +6,14 @@ import org.apache.spark.broadcast.Broadcast case class BroadcastValue(value: Annotation, t: Type, sc: SparkContext) { lazy val broadcast: Broadcast[Annotation] = sc.broadcast(value) + + lazy val regionValue: RegionValue = { + val rv = RegionValue(Region()) + val rvb = new RegionValueBuilder() + rvb.set(rv.region) + rvb.start(t) + rvb.addAnnotation(t, value) + rv.set(rv.region, rvb.end()) + rv + } } diff --git a/src/main/scala/is/hail/annotations/ExtendedOrdering.scala b/src/main/scala/is/hail/annotations/ExtendedOrdering.scala index 30bc429772a..70151bb2790 100644 --- a/src/main/scala/is/hail/annotations/ExtendedOrdering.scala +++ b/src/main/scala/is/hail/annotations/ExtendedOrdering.scala @@ -103,9 +103,10 @@ object ExtendedOrdering { def compareNonnull(x: T, y: T, missingGreatest: Boolean): Int = { val rx = x.asInstanceOf[Row] val ry = y.asInstanceOf[Row] - + + val commonPrefix = math.min(fieldOrd.length, math.min(rx.length, ry.length)) var i = 0 - while (i < fieldOrd.length) { + while (i < commonPrefix) { val c = fieldOrd(i).compare(rx.get(i), ry.get(i), missingGreatest) if (c != 0) return c diff --git a/src/main/scala/is/hail/annotations/OrderedRVIterator.scala b/src/main/scala/is/hail/annotations/OrderedRVIterator.scala index 52cde9b31fc..3608ac157c9 100644 --- a/src/main/scala/is/hail/annotations/OrderedRVIterator.scala +++ b/src/main/scala/is/hail/annotations/OrderedRVIterator.scala @@ -5,6 +5,15 @@ import is.hail.utils._ case class OrderedRVIterator(t: OrderedRVDType, iterator: Iterator[RegionValue]) { + def restrictToPKInterval(interval: Interval): Iterator[RegionValue] = { + val ur = new UnsafeRow(t.rowType) + val pk = new KeyedRow(ur, t.kRowFieldIdx) + iterator.filter { rv => { + ur.set(rv) + interval.contains(t.kType.ordering, pk) + } } + } + def staircase: StagingIterator[FlipbookIterator[RegionValue]] = iterator.toFlipbookIterator.staircased(t.kRowOrdView) diff --git a/src/main/scala/is/hail/annotations/UnsafeRow.scala b/src/main/scala/is/hail/annotations/UnsafeRow.scala index f5c4006d501..5b3f43dce8c 100644 --- a/src/main/scala/is/hail/annotations/UnsafeRow.scala +++ b/src/main/scala/is/hail/annotations/UnsafeRow.scala @@ -301,3 +301,13 @@ class UnsafeRow(var t: TBaseStruct, } } } + +class KeyedRow(var row: Row, keyFields: Array[Int]) extends Row { + def length: Int = row.size + def get(i: 
Int): Any = row.get(keyFields(i)) + def copy(): Row = new KeyedRow(row, keyFields) + def set(newRow: Row): KeyedRow = { + row = newRow + this + } +} diff --git a/src/main/scala/is/hail/expr/AST.scala b/src/main/scala/is/hail/expr/AST.scala index 93719037e70..5e48632f61a 100644 --- a/src/main/scala/is/hail/expr/AST.scala +++ b/src/main/scala/is/hail/expr/AST.scala @@ -660,17 +660,24 @@ case class Apply(posn: Position, fn: String, args: Array[AST]) extends AST(posn, } yield ir.ApplyBinaryPrimOp(op, x, y, t) } + private[this] def tryIRConversion(agg: Option[String]): Option[IR] = + for { + irArgs <- anyFailAllFail(args.map(_.toIR(agg))) + ir <- tryPrimOpConversion(args.map(_.`type`).zip(irArgs)).orElse( + IRFunctionRegistry.lookupConversion(fn, args.map(_.`type`)) + .map { irf => irf(irArgs) }) + } yield ir + def toIR(agg: Option[String] = None): Option[IR] = { fn match { - case "merge" | "select" | "drop" | "annotate" | "index" => + case "merge" | "select" | "drop" | "index" => None + case "annotate" => + if (!args(1).isInstanceOf[StructConstructor]) + return None + tryIRConversion(agg) case _ => - for { - irArgs <- anyFailAllFail(args.map(_.toIR(agg))) - ir <- tryPrimOpConversion(args.map(_.`type`).zip(irArgs)).orElse( - IRFunctionRegistry.lookupFunction(fn, args.map(_.`type`)) - .map { irf => irf(irArgs) }) - } yield ir + tryIRConversion(agg) } } } @@ -781,7 +788,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST case _ => for { irs <- anyFailAllFail((lhs +: args).map(_.toIR(agg))) - ir <- IRFunctionRegistry.lookupFunction(method, (lhs +: args).map(_.`type`)) + ir <- IRFunctionRegistry.lookupConversion(method, (lhs +: args).map(_.`type`)) .map { irf => irf(irs) } } yield ir } diff --git a/src/main/scala/is/hail/expr/FunctionRegistry.scala b/src/main/scala/is/hail/expr/FunctionRegistry.scala index 97e8aba3c5c..1fa80fd82fd 100644 --- a/src/main/scala/is/hail/expr/FunctionRegistry.scala +++ b/src/main/scala/is/hail/expr/FunctionRegistry.scala @@ -1013,6 +1013,11 @@ object FunctionRegistry { registerMethod("split", (s: String, p: String, n: Int) => s.split(p, n): IndexedSeq[String]) + registerMethod("lower", (s: String) => s.toLowerCase) + registerMethod("upper", (s: String) => s.toUpperCase) + registerMethod("strip", (s: String) => s.trim()) + registerMethod("contains", (s: String, t: String) => s.contains(t)) + registerMethod("replace", (str: String, pattern1: String, pattern2: String) => str.replaceAll(pattern1, pattern2)) diff --git a/src/main/scala/is/hail/expr/Relational.scala b/src/main/scala/is/hail/expr/Relational.scala index 9bc86a4fdd2..a26b82b5c14 100644 --- a/src/main/scala/is/hail/expr/Relational.scala +++ b/src/main/scala/is/hail/expr/Relational.scala @@ -379,6 +379,76 @@ case class FilterRows( } } +case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR { + + def children: IndexedSeq[BaseIR] = Array(child, newEntries) + + def copy(newChildren: IndexedSeq[BaseIR]): MapEntries = { + assert(newChildren.length == 2) + MapEntries(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR]) + } + + val newRow = { + val arrayLength = ArrayLen(GetField(Ref("va"), MatrixType.entriesIdentifier)) + val idxEnv = new Env[IR]() + .bind("g", ArrayRef(GetField(Ref("va"), MatrixType.entriesIdentifier), Ref("i"))) + .bind("sa", ArrayRef(Ref("sa"), Ref("i"))) + val entries = ArrayMap(ArrayRange(I32(0), arrayLength, I32(1)), "i", Subst(newEntries, idxEnv)) + InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, entries))) + } + + 
val typ: MatrixType = { + Infer(newRow, None, new Env[Type]() + .bind("global", child.typ.globalType) + .bind("va", child.typ.rvRowType) + .bind("sa", TArray(child.typ.colType)) + ) + child.typ.copy(rvRowType = newRow.typ) + } + + def execute(hc: HailContext): MatrixValue = { + val prev = child.execute(hc) + + val localGlobalsType = typ.globalType + val localColsType = TArray(typ.colType) + val colValuesBc = prev.colValuesBc + val globalsBc = prev.globals.broadcast + + val (rTyp, f) = ir.Compile[Long, Long, Long, Long]( + "global", localGlobalsType, + "va", prev.typ.rvRowType, + "sa", localColsType, + newRow) + assert(rTyp == typ.rvRowType) + + val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType) { it => + val rvb = new RegionValueBuilder() + val newRV = RegionValue() + val rowF = f() + + it.map { rv => + val region = rv.region + val oldRow = rv.offset + + rvb.set(region) + rvb.start(localGlobalsType) + rvb.addAnnotation(localGlobalsType, globalsBc.value) + val globals = rvb.end() + + rvb.start(localColsType) + rvb.addAnnotation(localColsType, colValuesBc.value) + val cols = rvb.end() + + val off = rowF(region, globals, false, oldRow, false, cols, false) + + newRV.set(region, off) + newRV + } + } + prev.copy(typ = typ, rvd = newRVD) + } +} + case class TableValue(typ: TableType, globals: BroadcastValue, rvd: RVD) { def rdd: RDD[Row] = { val localRowType = typ.rowType @@ -495,6 +565,109 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR { } } +case class TableJoin(left: TableIR, right: TableIR, joinType: String) extends TableIR { + require(left.typ.keyType isIsomorphicTo right.typ.keyType) + + val children: IndexedSeq[BaseIR] = Array(left, right) + + private val joinedFields = left.typ.keyType.fields ++ + left.typ.valueType.fields ++ + right.typ.valueType.fields + private val preNames = joinedFields.map(_.name).toArray + private val (finalColumnNames, remapped) = mangle(preNames) + + val newRowType = TStruct(joinedFields.zipWithIndex.map { + case (fd, i) => (finalColumnNames(i), fd.typ) + }: _*) + + val typ: TableType = left.typ.copy(rowType = newRowType) + + def copy(newChildren: IndexedSeq[BaseIR]): TableJoin = { + assert(newChildren.length == 2) + TableJoin( + newChildren(0).asInstanceOf[TableIR], + newChildren(1).asInstanceOf[TableIR], + joinType ) + } + + def execute(hc: HailContext): TableValue = { + val leftTV = left.execute(hc) + val rightTV = right.execute(hc) + val leftRowType = left.typ.rowType + val rightRowType = right.typ.rowType + val leftKeyFieldIdx = left.typ.keyFieldIdx + val rightKeyFieldIdx = right.typ.keyFieldIdx + val leftValueFieldIdx = left.typ.valueFieldIdx + val rightValueFieldIdx = right.typ.valueFieldIdx + val localNewRowType = newRowType + val rvMerger: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it => + val rvb = new RegionValueBuilder() + val rv = RegionValue() + it.map { joined => + val lrv = joined._1 + val rrv = joined._2 + + if (lrv != null) + rvb.set(lrv.region) + else { + assert(rrv != null) + rvb.set(rrv.region) + } + + rvb.start(localNewRowType) + rvb.startStruct() + + if (lrv != null) + rvb.addFields(leftRowType, lrv, leftKeyFieldIdx) + else { + assert(rrv != null) + rvb.addFields(rightRowType, rrv, rightKeyFieldIdx) + } + + if (lrv != null) + rvb.addFields(leftRowType, lrv, leftValueFieldIdx) + else + rvb.skipFields(leftValueFieldIdx.length) + + if (rrv != null) + rvb.addFields(rightRowType, rrv, rightValueFieldIdx) + else + rvb.skipFields(rightValueFieldIdx.length) + + rvb.endStruct() + 
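// The merged struct is now complete: key fields copied from whichever row is
+ // present, then the left value fields (skipped, i.e. set missing, when the
+ // left row is absent), then the right value fields likewise.
+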
rv.set(rvb.region, rvb.end()) + rv + } + } + val leftORVD = leftTV.rvd match { + case ordered: OrderedRVD => ordered + case unordered => + OrderedRVD( + new OrderedRVDType(left.typ.key.toArray, left.typ.key.toArray, leftRowType), + unordered.rdd, + None, + None) + } + val rightORVD = rightTV.rvd match { + case ordered: OrderedRVD => ordered + case unordered => + val ordType = + new OrderedRVDType(right.typ.key.toArray, right.typ.key.toArray, rightRowType) + if (joinType == "left" || joinType == "inner") + unordered.constrainToOrderedPartitioner(ordType, leftORVD.partitioner) + else + OrderedRVD(ordType, unordered.rdd, None, Some(leftORVD.partitioner)) + } + val joinedRVD = leftORVD.orderedJoin( + rightORVD, + joinType, + rvMerger, + new OrderedRVDType(leftORVD.typ.partitionKey, leftORVD.typ.key, newRowType)) + + TableValue(typ, leftTV.globals, joinedRVD) + } +} + case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val children: IndexedSeq[BaseIR] = Array(child, newRow) @@ -537,3 +710,36 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { }) } } + +case class TableMapGlobals(child: TableIR, newRow: IR) extends TableIR { + val children: IndexedSeq[BaseIR] = Array(child, newRow) + + val typ: TableType = { + Infer(newRow, None, child.typ.env) + child.typ.copy(globalType = newRow.typ.asInstanceOf[TStruct]) + } + + def copy(newChildren: IndexedSeq[BaseIR]): TableMapGlobals = { + assert(newChildren.length == 2) + TableMapGlobals(newChildren(0).asInstanceOf[TableIR], newChildren(1).asInstanceOf[IR]) + } + + def execute(hc: HailContext): TableValue = { + val tv = child.execute(hc) + val gType = typ.globalType + + val (rTyp, f) = ir.Compile[Long, Long]( + "global", child.typ.globalType, + newRow) + assert(rTyp == gType) + + val rv = tv.globals.regionValue + val offset = f()(rv.region, rv.offset, false) + + val newGlobals = tv.globals.copy( + value = UnsafeRow.read(rTyp, rv.region, offset), + t = rTyp) + + TableValue(typ, newGlobals, tv.rvd) + } +} diff --git a/src/main/scala/is/hail/expr/ir/Children.scala b/src/main/scala/is/hail/expr/ir/Children.scala index 77aa8c256a4..9cc17183d98 100644 --- a/src/main/scala/is/hail/expr/ir/Children.scala +++ b/src/main/scala/is/hail/expr/ir/Children.scala @@ -76,7 +76,7 @@ object Children { none case Die(message) => none - case ApplyFunction(impl, args) => + case Apply(_, args, _) => args.toIndexedSeq } } diff --git a/src/main/scala/is/hail/expr/ir/Compile.scala b/src/main/scala/is/hail/expr/ir/Compile.scala index 673a7429758..57294222814 100644 --- a/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/src/main/scala/is/hail/expr/ir/Compile.scala @@ -35,6 +35,14 @@ object Compile { apply[AsmFunction1[Region, R], R](Seq(), body) } + def apply[T0: TypeInfo : ClassTag, R: TypeInfo : ClassTag]( + name0: String, + typ0: Type, + body: IR): (Type, () => AsmFunction3[Region, T0, Boolean, R]) = { + + apply[AsmFunction3[Region, T0, Boolean, R], R](Seq((name0, typ0, classTag[T0])), body) + } + def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, R: TypeInfo : ClassTag]( name0: String, typ0: Type, @@ -45,6 +53,30 @@ object Compile { apply[AsmFunction5[Region, T0, Boolean, T1, Boolean, R], R](Seq((name0, typ0, classTag[T0]), (name1, typ1, classTag[T1])), body) } + def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag, R: TypeInfo : ClassTag]( + name0: String, + typ0: Type, + name1: String, + typ1: Type, + name2: String, + typ2: Type, + body: IR): (Type, () => AsmFunction7[Region, T0, Boolean, T1, 
Boolean, T2, Boolean, R]) = { + assert(TypeToIRIntermediateClassTag(typ0) == classTag[T0]) + assert(TypeToIRIntermediateClassTag(typ1) == classTag[T1]) + assert(TypeToIRIntermediateClassTag(typ2) == classTag[T2]) + val fb = FunctionBuilder.functionBuilder[Region, T0, Boolean, T1, Boolean, T2, Boolean, R] + var e = body + val env = new Env[IR]() + .bind(name0, In(0, typ0)) + .bind(name1, In(1, typ1)) + .bind(name2, In(2, typ2)) + e = Subst(e, env) + Infer(e) + assert(TypeToIRIntermediateClassTag(e.typ) == classTag[R]) + Emit(e, fb) + (e.typ, fb.result()) + } + def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag, T3: TypeInfo : ClassTag, T4: TypeInfo : ClassTag, T5: TypeInfo : ClassTag, R: TypeInfo : ClassTag]( diff --git a/src/main/scala/is/hail/expr/ir/Copy.scala b/src/main/scala/is/hail/expr/ir/Copy.scala index 413d12a06eb..ee7eac163e3 100644 --- a/src/main/scala/is/hail/expr/ir/Copy.scala +++ b/src/main/scala/is/hail/expr/ir/Copy.scala @@ -104,8 +104,8 @@ object Copy { same case Die(message) => same - case ApplyFunction(impl, args) => - ApplyFunction(impl, children) + case Apply(fn, args, impl) => + Apply(fn, children, impl) } } } diff --git a/src/main/scala/is/hail/expr/ir/Emit.scala b/src/main/scala/is/hail/expr/ir/Emit.scala index 438b8ef06ba..c61c0d47d44 100644 --- a/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/src/main/scala/is/hail/expr/ir/Emit.scala @@ -501,7 +501,7 @@ private class Emit( present(fb.getArg[Boolean](i * 2 + 3)) case Die(m) => present(Code._throw(Code.newInstance[RuntimeException, String](m))) - case ApplyFunction(impl, args) => + case Apply(fn, args, impl) => val meth = methods.getOrElseUpdate(impl, { impl.argTypes.foreach(_.clear()) (impl.argTypes, args.map(a => a.typ)).zipped.foreach(_.unify(_)) diff --git a/src/main/scala/is/hail/expr/ir/IR.scala b/src/main/scala/is/hail/expr/ir/IR.scala index 8cc39ea6102..961ff67ba62 100644 --- a/src/main/scala/is/hail/expr/ir/IR.scala +++ b/src/main/scala/is/hail/expr/ir/IR.scala @@ -84,4 +84,4 @@ final case class InMissingness(i: Int) extends IR { val typ: Type = TBoolean() } // FIXME: should be type any final case class Die(message: String) extends IR { val typ = TVoid } -final case class ApplyFunction(implementation: IRFunction, args: Seq[IR]) extends IR { val typ = implementation.returnType } +final case class Apply(function: String, args: Seq[IR], var implementation: IRFunction = null) extends IR { def typ = implementation.returnType } diff --git a/src/main/scala/is/hail/expr/ir/Infer.scala b/src/main/scala/is/hail/expr/ir/Infer.scala index ed248a1f418..811219d414b 100644 --- a/src/main/scala/is/hail/expr/ir/Infer.scala +++ b/src/main/scala/is/hail/expr/ir/Infer.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.types._ object Infer { @@ -156,7 +157,7 @@ object Infer { case x@GetField(o, name, _) => infer(o) val t = coerce[TStruct](o.typ) - assert(t.index(name).nonEmpty) + assert(t.index(name).nonEmpty, s"$name not in $t") x.typ = -t.field(name).typ case GetFieldMissingness(o, name) => infer(o) @@ -174,9 +175,11 @@ object Infer { assert(typ != null) case InMissingness(i) => case Die(msg) => - case ApplyFunction(impl, args) => + case x@Apply(fn, args, impl) => args.foreach(infer(_)) - assert(args.map(_.typ).zip(impl.argTypes).forall {case (i, j) => j.unify(i)}) + if (impl == null) + x.implementation = IRFunctionRegistry.lookupFunction(fn, args.map(_.typ)).get + assert(args.map(_.typ).zip(x.implementation.argTypes).forall 
{case (i, j) => j.unify(i)}) } } diff --git a/src/main/scala/is/hail/expr/ir/Recur.scala b/src/main/scala/is/hail/expr/ir/Recur.scala index d93a9a9bf1b..41af8a885f1 100644 --- a/src/main/scala/is/hail/expr/ir/Recur.scala +++ b/src/main/scala/is/hail/expr/ir/Recur.scala @@ -41,6 +41,6 @@ object Recur { case In(i, typ) => ir case InMissingness(i) => ir case Die(message) => ir - case ApplyFunction(impl, args) => ApplyFunction(impl, args.map(f)) + case Apply(fn, args, impl) => Apply(fn, args.map(f), impl) } } diff --git a/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 0ab43c664a7..586e34a10ce 100644 --- a/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -7,38 +7,60 @@ import is.hail.utils._ import is.hail.asm4s.coerce import scala.collection.mutable -import scala.reflect.ClassTag object IRFunctionRegistry { - val registry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty) + val irRegistry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty) + + val codeRegistry: mutable.Map[String, Seq[(Seq[Type], IRFunction)]] = mutable.Map().withDefaultValue(Seq.empty) def addIRFunction(f: IRFunction) { - val l = registry(f.name) - registry.put(f.name, - l :+ (f.argTypes, { args: Seq[IR] => - ApplyFunction(f, args) - })) + val l = codeRegistry(f.name) + codeRegistry.put(f.name, + l :+ (f.argTypes, f)) } def addIR(name: String, types: Seq[Type], f: Seq[IR] => IR) { - val l = registry(name) - registry.put(name, l :+ ((types, f))) + val l = irRegistry(name) + irRegistry.put(name, l :+ ((types, f))) + } + + def lookupFunction(name: String, args: Seq[Type]): Option[IRFunction] = { + val validF = codeRegistry(name).flatMap { case (ts, f) => + if (ts.length == args.length) { + ts.foreach(_.clear()) + if ((ts, args).zipped.forall(_.unify(_))) + Some(f) + else + None + } else + None + } + + validF match { + case Seq() => None + case Seq(x) => Some(x) + case _ => fatal(s"Multiple IRFunctions found that satisfy $name$args.") + } } - def lookupFunction(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = { + def lookupConversion(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = { assert(args.forall(_ != null)) - val validMethods = registry(name).flatMap { case (ts, f) => + val validIR = irRegistry(name).flatMap { case (ts, f) => if (ts.length == args.length) { ts.foreach(_.clear()) if ((ts, args).zipped.forall(_.unify(_))) Some(f) else None - } else { + } else None - } } + + val validMethods = validIR ++ lookupFunction(name, args).map { f => + { args: Seq[IR] => Apply(name, args, f) } + } + validMethods match { case Seq() => None case Seq(x) => Some(x) @@ -50,6 +72,7 @@ object IRFunctionRegistry { GenotypeFunctions.registerAll() MathFunctions.registerAll() UtilFunctions.registerAll() + StringFunctions.registerAll() } abstract class RegistryFunctions { diff --git a/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala b/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala new file mode 100644 index 00000000000..e26ca14f35e --- /dev/null +++ b/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala @@ -0,0 +1,22 @@ +package is.hail.expr.ir.functions + +import is.hail.expr.types._ + +object StringFunctions extends RegistryFunctions { + + def upper(s: String): String = s.toUpperCase + + def lower(s: String): String = s.toLowerCase + + def strip(s: 
String): String = s.trim() + + def contains(s: String, t: String): Boolean = s.contains(t) + + def registerAll(): Unit = { + val thisClass = getClass + registerScalaFunction("upper", TString(), TString())(thisClass, "upper") + registerScalaFunction("lower", TString(), TString())(thisClass, "lower") + registerScalaFunction("strip", TString(), TString())(thisClass, "strip") + registerScalaFunction("contains", TString(), TString(), TBoolean())(thisClass, "contains") + } +} diff --git a/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala b/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala index 4fa386b9d16..612cdba3f84 100644 --- a/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala +++ b/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala @@ -41,5 +41,9 @@ object UtilFunctions extends RegistryFunctions { registerIR("range", TInt32(), TInt32())(ArrayRange(_, _, I32(1))) registerIR("range", TInt32())(ArrayRange(I32(0), _, I32(1))) + + registerIR("annotate", tv("T", _.isInstanceOf[TStruct]), tv("U", _.isInstanceOf[TStruct])) { (s, annotations) => + InsertFields(s, annotations.asInstanceOf[MakeStruct].fields) + } } } diff --git a/src/main/scala/is/hail/expr/types/TBaseStruct.scala b/src/main/scala/is/hail/expr/types/TBaseStruct.scala index 0c62f6bae2c..0b9e754d4b8 100644 --- a/src/main/scala/is/hail/expr/types/TBaseStruct.scala +++ b/src/main/scala/is/hail/expr/types/TBaseStruct.scala @@ -10,6 +10,11 @@ import org.json4s.jackson.JsonMethods import scala.reflect.{ClassTag, classTag} object TBaseStruct { + /** + * Define an ordering on Row objects. Works with any row r such that the list + * of types of r is a prefix of types, or types is a prefix of the list of + * types of r. + */ def getOrdering(types: Array[Type]): ExtendedOrdering = ExtendedOrdering.rowOrdering(types.map(_.ordering)) diff --git a/src/main/scala/is/hail/expr/types/TStruct.scala b/src/main/scala/is/hail/expr/types/TStruct.scala index 686a0d91091..4307ef12af0 100644 --- a/src/main/scala/is/hail/expr/types/TStruct.scala +++ b/src/main/scala/is/hail/expr/types/TStruct.scala @@ -90,6 +90,13 @@ final case class TStruct(fields: IndexedSeq[Field], override val required: Boole case _ => false } + def isIsomorphicTo(other: TStruct): Boolean = + size == other.size && isPrefixOf(other) + + def isPrefixOf(other: TStruct): Boolean = + size <= other.size && + fields.zip(other.fields).forall{ case (l, r) => l.typ isOfType r.typ } + override def subst() = TStruct(fields.map(f => f.copy(typ = f.typ.subst().asInstanceOf[Type]))) def index(str: String): Option[Int] = fieldIdx.get(str) diff --git a/src/main/scala/is/hail/expr/types/TableType.scala b/src/main/scala/is/hail/expr/types/TableType.scala index e986ff08344..722fce9f930 100644 --- a/src/main/scala/is/hail/expr/types/TableType.scala +++ b/src/main/scala/is/hail/expr/types/TableType.scala @@ -17,6 +17,12 @@ case class TableType(rowType: TStruct, key: IndexedSeq[String], globalType: TStr .bind(("row", rowType)) } + def keyType: TStruct = rowType.select(key.toArray)._1 + val keyFieldIdx: Array[Int] = key.toArray.map(rowType.fieldIdx) + def valueType: TStruct = rowType.filter(key.toSet, include = false)._1 + val valueFieldIdx: Array[Int] = + rowType.fields.filter(f => !key.contains(f.name)).map(_.index).toArray + def pretty(sb: StringBuilder, indent0: Int = 0, compact: Boolean = false) { var indent = indent0 diff --git a/src/main/scala/is/hail/linalg/LocalMatrix.scala b/src/main/scala/is/hail/linalg/LocalMatrix.scala deleted file mode 100644 index 
616ad328122..00000000000 --- a/src/main/scala/is/hail/linalg/LocalMatrix.scala +++ /dev/null @@ -1,325 +0,0 @@ -package is.hail.linalg - -import is.hail.HailContext -import is.hail.stats.eigSymD -import is.hail.utils.richUtils.RichDenseMatrixDouble -import is.hail.utils._ - -import scala.collection.immutable.Range -import breeze.linalg.{* => B_*, DenseMatrix => BDM, DenseVector => BDV, cholesky => breezeCholesky, qr => breezeQR, svd => breezeSVD, _} -import breeze.numerics.{pow => breezePow, sqrt => breezeSqrt} -import breeze.stats.distributions.{RandBasis, ThreadLocalRandomGenerator} -import is.hail.io._ -import org.apache.commons.math3.random.MersenneTwister - - -object LocalMatrix { - type M = LocalMatrix - - private val bufferSpec: BufferSpec = - new BlockingBufferSpec(32 * 1024, - new LZ4BlockBufferSpec(32 * 1024, - new StreamBlockBufferSpec)) - - private val sclrType = 0 - private val colType = 1 - private val rowType = 2 - private val matType = 3 - - def apply(m: BDM[Double]): M = new LocalMatrix(m) - - // vector => matrix with single column - def apply(v: BDV[Double]): M = { - val data = if (v.length == v.data.length) v.data else v.toArray - LocalMatrix(data) - } - - // array => matrix with single column - def apply(data: Array[Double]): M = { - require(data.length > 0) - LocalMatrix(data.length, 1, data) - } - - def zeros(nRows: Int, nCols: Int): M = - LocalMatrix(new BDM[Double](nRows, nCols)) - - def apply(nRows: Int, nCols: Int, data: Array[Double], isTransposed: Boolean = false): M = { - require(nRows * nCols == data.length) - val m = new BDM[Double](nRows, nCols, data, 0, if (isTransposed) nCols else nRows, isTransposed) - LocalMatrix(m) - } - - def apply(nRows: Int, nCols: Int, data: Array[Double], offset: Int, majorStride: Int, isTransposed: Boolean): M = { - val m = new BDM[Double](nRows, nCols, data, offset, majorStride, isTransposed) - LocalMatrix(m) - } - - def read(hc: HailContext, path: String): M = - new LocalMatrix(RichDenseMatrixDouble.read(hc, path, LocalMatrix.bufferSpec)) - - def random(nRows: Int, nCols: Int, seed: Int = 0, uniform: Boolean = true): M = { - val randBasis: RandBasis = new RandBasis(new ThreadLocalRandomGenerator(new MersenneTwister(seed))) - val rand = if (uniform) randBasis.uniform else randBasis.gaussian - LocalMatrix(BDM.rand(nRows, nCols, rand)) - } - - def outerSum(row: Array[Double], col: Array[Double]): LocalMatrix = { - val nRows = col.length - val nCols = row.length - assert(nRows > 0 && nCols > 0 && (nRows * nCols.toLong < Int.MaxValue)) - - val a = new Array[Double](nRows * nCols) - var j = 0 - while (j < nCols) { - var i = 0 - while (i < nRows) { - a(j * nRows + i) = col(i) + row(j) - i += 1 - } - j += 1 - } - LocalMatrix(nRows, nCols, a) - } - - object ops { - implicit class Shim(l: M) { - def +(r: M): M = l.add(r) - - def -(r: M): M = l.subtract(r) - - def *(r: M): M = l.multiply(r) - - def /(r: M): M = l.divide(r) - - def +(r: Double): M = l.add(r) - - def -(r: Double): M = l.subtract(r) - - def *(r: Double): M = l.multiply(r) - - def /(r: Double): M = l.divide(r) - } - - implicit class ScalarShim(l: Double) { - def +(r: M): M = r.add(l) - - def -(r: M): M = r.rsubtract(l) - - def *(r: M): M = r.multiply(l) - - def /(r: M): M = r.rdivide(l) - } - } - - def checkShapes(m1: LocalMatrix, m2: LocalMatrix, op: String): (Int, Int) = { - val shapeTypes = (m1.shapeType, m2.shapeType) - - val compatible = shapeTypes match { - case (`matType`, `matType`) => m1.nRows == m2.nRows && m1.nCols == m2.nCols - case (`matType`, `rowType`) => 
m1.nCols == m2.nCols - case (`matType`, `colType`) => m1.nRows == m2.nRows - case (`rowType`, `matType`) => m1.nCols == m2.nCols - case (`rowType`, `rowType`) => m1.nCols == m2.nCols - case (`colType`, `matType`) => m1.nRows == m2.nRows - case (`colType`, `colType`) => m1.nRows == m2.nRows - case _ => true - } - - if (!compatible) - fatal(s"Incompatible shapes for $op with broadcasting: ${ m1.shape } and ${ m2.shape }") - - shapeTypes - } -} - -// Matrix with NumPy-style ops and broadcasting -class LocalMatrix(val m: BDM[Double]) { - val nRows: Int = m.rows - - val nCols: Int = m.cols - - def shape: (Int, Int) = (nRows, nCols) - - val isTranspose: Boolean = m.isTranspose - - def asArray: Array[Double] = - if (m.isCompact && !isTranspose) - m.data - else - toArray - - def toArray: Array[Double] = m.toArray - - def copy() = LocalMatrix(m.copy) - - def write(hc: HailContext, path: String) { - m.write(hc, path, bufferSpec = LocalMatrix.bufferSpec) - } - - // writeBlockMatrix and BlockMatrix.read is safer than toBlockMatrix - def writeBlockMatrix(hc: HailContext, path: String, blockSize: Int = BlockMatrix.defaultBlockSize, - forceRowMajor: Boolean = false) { - m.writeBlockMatrix(hc, path, blockSize, forceRowMajor) - } - - def toBlockMatrix(hc: HailContext, blockSize: Int = BlockMatrix.defaultBlockSize): BlockMatrix = - BlockMatrix.fromBreezeMatrix(hc.sc, m, blockSize) - - def apply(i: Int, j: Int): Double = m(i, j) - - def apply(i: Int, jj: Range) = LocalMatrix(m(i to i, jj)) - - def apply(ii: Range, j: Int) = LocalMatrix(m(ii, j to j)) - - def apply(ii: Range, jj: Range) = LocalMatrix(m(ii, jj)) - - def add(e: Double) = LocalMatrix(m + e) - - def subtract(e: Double) = LocalMatrix(m - e) - def rsubtract(e: Double) = LocalMatrix(e - m) - - def multiply(e: Double) = LocalMatrix(m * e) - - def divide(e: Double) = LocalMatrix(m / e) - def rdivide(e: Double) = LocalMatrix(e / m) - - def unary_+ = this - - def unary_- = LocalMatrix(-m) - - import LocalMatrix.{sclrType, colType, rowType, matType} - import LocalMatrix.ops._ - - private val shapeType: Int = - if (nRows > 1) { - if (nCols > 1) matType else colType - } else { - if (nCols > 1) rowType else sclrType - } - - def add(that: LocalMatrix): LocalMatrix = { - val (st, st2) = LocalMatrix.checkShapes(this, that, "addition") - if (st == st2) - LocalMatrix(m + that.m) - else if (st2 == sclrType) - this + that(0, 0) - else if (st == sclrType) - that + m(0, 0) - else - (st, st2) match { - case (`matType`, `colType`) => LocalMatrix(m(::, B_*) +:+ BDV(that.asArray)) - case (`matType`, `rowType`) => LocalMatrix(m(B_*, ::) +:+ BDV(that.asArray)) - case (`colType`, `matType`) => LocalMatrix(that.m(::, B_*) +:+ BDV(this.asArray)) - case (`rowType`, `matType`) => LocalMatrix(that.m(B_*, ::) +:+ BDV(this.asArray)) - case (`colType`, `rowType`) => LocalMatrix.outerSum(row = that.asArray, col = this.asArray) - case (`rowType`, `colType`) => LocalMatrix.outerSum(row = this.asArray, col = that.asArray) - } - } - - def subtract(that: LocalMatrix): LocalMatrix = { - val (st, st2) = LocalMatrix.checkShapes(this, that, "subtraction") - if (st == st2) - LocalMatrix(m - that.m) - else if (st2 == sclrType) - this - that(0, 0) - else if (st == sclrType) - m(0, 0) - that - else - (st, st2) match { - case (`matType`, `colType`) => LocalMatrix(m(::, B_*) -:- BDV(that.asArray)) - case (`matType`, `rowType`) => LocalMatrix(m(B_*, ::) -:- BDV(that.asArray)) - case _ => this + (-that) // FIXME: room for improvement - } - } - - def multiply(that: LocalMatrix): LocalMatrix = { - 
val (st, st2) = LocalMatrix.checkShapes(this, that, "pointwise multiplication") - if (st == st2) - LocalMatrix(m *:* that.m) - else if (st2 == sclrType) - this * that(0, 0) - else if (st == sclrType) - that * m(0, 0) - else - (shapeType, that.shapeType) match { - case (`matType`, `colType`) => LocalMatrix(m(::, B_*) *:* BDV(that.asArray)) - case (`matType`, `rowType`) => LocalMatrix(m(B_*, ::) *:* BDV(that.asArray)) - case (`colType`, `matType`) => LocalMatrix(that.m(::, B_*) *:* BDV(this.asArray)) - case (`rowType`, `matType`) => LocalMatrix(that.m(B_*, ::) *:* BDV(this.asArray)) - case (`colType`, `rowType`) => LocalMatrix(m * that.m) - case (`rowType`, `colType`) => LocalMatrix(that.m * m) - } - } - - def divide(that: LocalMatrix): LocalMatrix = { - val (st, st2) = LocalMatrix.checkShapes(this, that, "pointwise division") - if (st == st2) - LocalMatrix(m /:/ that.m) - else if (st2 == sclrType) - this / that(0, 0) - else if (st == sclrType) - m(0, 0) / that - else - (st, st2) match { - case (`matType`, `colType`) => LocalMatrix(m(::, B_*) /:/ BDV(that.asArray)) - case (`matType`, `rowType`) => LocalMatrix(m(B_*, ::) /:/ BDV(that.asArray)) - case _ => this * (1.0 / that) // FIXME: room for improvement - } - } - - def sqrt(): LocalMatrix = LocalMatrix(breezeSqrt(m)) - - def pow(e: Double): LocalMatrix = LocalMatrix(breezePow(m, e)) - - def t = LocalMatrix(m.t) - - def diagonal(): LocalMatrix = LocalMatrix(diag(m).toArray) - - def inverse() = LocalMatrix(inv(m)) - - def matrixMultiply(that: LocalMatrix) = LocalMatrix(m * that.m) - - // solve for X in AX = B, with A = this and B = that - def solve(that: LocalMatrix) = LocalMatrix(m \ that.m) - - // eigendecomposition of symmetric matrix using lapack.dsyevd (Divide and Conquer) - // X = USU^T with U orthonormal and S diagonal - // returns (diag(S), U) - // no symmetry check, uses lower triangle only - def eigh(): (LocalMatrix, LocalMatrix) = { - val (evals, optEvects) = eigSymD.doeigSymD(m, rightEigenvectors = true) - (LocalMatrix(evals), LocalMatrix(optEvects.get)) - } - - def eigvalsh(): LocalMatrix = { - val (evals, _) = eigSymD.doeigSymD(m, rightEigenvectors = false) - LocalMatrix(evals) - } - - // singular value decomposition of n x m matrix X using lapack.dgesdd (Divide and Conquer) - // X = USV^T with U and V orthonormal and S diagonal - // returns (U, singular values, V^T) with the following dimensions where k = min(n, m) - // regular: n x n, k x 1, m x m - // reduced: n x k, k x 1, k x m - def svd(reduced: Boolean = false): (LocalMatrix, LocalMatrix, LocalMatrix) = { - val res = if (reduced) breezeSVD.reduced(m) else breezeSVD(m) - (LocalMatrix(res.leftVectors), LocalMatrix(res.singularValues), LocalMatrix(res.rightVectors)) - } - - // QR decomposition of n x m matrix X - // X = QR with Q orthonormal columns and R upper triangular - // returns (Q, R) with the following dimensions where k = min(n, m) - // regular: n x m, m x m - // reduced: n x k, k x m - def qr(reduced: Boolean = false): (LocalMatrix, LocalMatrix) = { - val res = if (reduced) breezeQR.reduced(m) else breezeQR(m) - (LocalMatrix(res.q), LocalMatrix(res.r)) - } - - // Cholesky factor L of a symmetric, positive-definite matrix X - // X = LL^T with L lower triangular - def cholesky(): LocalMatrix = { - m.forceSymmetry() // needed to prevent failure of symmetry check, even though lapack only uses lower triangle - LocalMatrix(breezeCholesky(m)) - } -} diff --git a/src/main/scala/is/hail/methods/LDPrune.scala b/src/main/scala/is/hail/methods/LDPrune.scala index 
6e29d67f724..85b6bdf1f8a 100644 --- a/src/main/scala/is/hail/methods/LDPrune.scala +++ b/src/main/scala/is/hail/methods/LDPrune.scala @@ -492,6 +492,6 @@ object LDPrune { val ((globalPrunedRDD, nVariantsFinal), globalDuration) = time(pruneGlobal(rddLP2, r2Threshold, windowSize)) info(s"LD prune step 3 of 3: nVariantsKept=$nVariantsFinal, time=${ formatTime(globalDuration) }") - vsm.copy2(rvd = vsm.rvd.copy(rdd = vsm.rvd.orderedJoinDistinct(globalPrunedRDD, "inner").map(_.rvLeft))) + vsm.copy2(rvd = vsm.rvd.orderedJoinDistinct(globalPrunedRDD, "inner", _.map(_.rvLeft), vsm.rvd.typ)) } }
diff --git a/src/main/scala/is/hail/methods/PCRelate.scala b/src/main/scala/is/hail/methods/PCRelate.scala index 7cba26eb5cb..04b7b8f0f99 100644 --- a/src/main/scala/is/hail/methods/PCRelate.scala +++ b/src/main/scala/is/hail/methods/PCRelate.scala @@ -278,14 +278,14 @@ class PCRelate(maf: Double, blockSize: Int, statistics: PCRelate.StatisticSubset def apply(blockedG: M, pcs: DenseMatrix[Double]): Result[M] = { val preMu = this.mu(blockedG, pcs) - val mu = (BlockMatrix.map2 { (g, mu) => if (badgt(g) || badmu(mu)) Double.NaN else mu - } (blockedG, preMu)).cache() + val mu = BlockMatrix.map2 { (g, mu) => if (badgt(g) || badmu(mu)) Double.NaN else mu + } (blockedG, preMu).cache() val variance = cacheWhen(PhiK2)( - mu.map(mu => if (mu.isNaN()) 0.0 else mu * (1.0 - mu))) + mu.map(mu => if (mu.isNaN) 0.0 else mu * (1.0 - mu))) val phi = cacheWhen(PhiK2)( this.phi(mu, variance, blockedG)) @@ -296,7 +296,7 @@ class PCRelate(maf: Double, blockSize: Int, statistics: PCRelate.StatisticSubset val k0 = cacheWhen(PhiK2K0K1)( this.k0(phi, mu, k2, blockedG, ibs0(blockedG, mu, blockSize))) if (statistics >= PhiK2K0K1) { - val k1 = (1.0 - (k2 +:+ k0)) + val k1 = 1.0 - (k2 +:+ k0) Result(phi, k0, k1, k2) } else Result(phi, k0, null, k2) @@ -350,7 +350,7 @@ class PCRelate(maf: Double, blockSize: Int, statistics: PCRelate.StatisticSubset } private[methods] def k2(phi: M, mu: M, variance: M, g: M): M = { - val twoPhi_ii = phi.diagonal.map(2.0 * _) + val twoPhi_ii = phi.diagonal().map(2.0 * _) val normalizedGD = g.map2WithIndex(mu, { case (_, i, g, mu) => if (mu.isNaN) 0.0 // https://github.com/Bioconductor-mirror/GENESIS/blob/release-3.5/R/pcrelate.R#L391
diff --git a/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala b/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala new file mode 100644 index 00000000000..9a34fce078d --- /dev/null +++ b/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala @@ -0,0 +1,96 @@ +package is.hail.rvd + +import is.hail.annotations.{JoinedRegionValue, OrderedRVIterator, RegionValue} +import org.apache.spark.rdd.RDD +import is.hail.utils.fatal + +class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { + val typ: OrderedRVDType = rvd.typ + val (kType, _) = rvd.rowType.select(key) + require(kType isPrefixOf rvd.typ.kType) + + private def checkJoinCompatability(right: KeyedOrderedRVD) { + if (!(kType isIsomorphicTo right.kType)) + fatal( + s"""Incompatible join keys. Keys must have same length and types, in order: + | Left key type: ${ kType.toString } + | Right key type: ${ right.kType.toString } + """.stripMargin) + } + + def orderedJoin( + right: KeyedOrderedRVD, + joinType: String, + joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joinedType: OrderedRVDType + ): OrderedRVD = { + checkJoinCompatability(right) + val lTyp = typ + val rTyp = right.typ + + val newPartitioner = + this.rvd.partitioner.enlargeToRange(right.rvd.partitioner.range) + val repartitionedLeft = + this.rvd.constrainToOrderedPartitioner(this.typ, newPartitioner) + val repartitionedRight = + right.rvd.constrainToOrderedPartitioner(right.typ, newPartitioner) + val compute: (OrderedRVIterator, OrderedRVIterator) => Iterator[JoinedRegionValue] = + (joinType: @unchecked) match { + case "inner" => _.innerJoin(_) + case "left" => _.leftJoin(_) + case "right" => _.rightJoin(_) + case "outer" => _.outerJoin(_) + } + val joinedRDD = + repartitionedLeft.rdd.zipPartitions(repartitionedRight.rdd, true) { + (leftIt, rightIt) => + joiner(compute( + OrderedRVIterator(lTyp, leftIt), + OrderedRVIterator(rTyp, rightIt))) + } + + new OrderedRVD(joinedType, newPartitioner, joinedRDD) + } + + def orderedJoinDistinct( + right: KeyedOrderedRVD, + joinType: String, + joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joinedType: OrderedRVDType + ): OrderedRVD = { + checkJoinCompatability(right) + val rekeyedLTyp = new OrderedRVDType(typ.partitionKey, key, typ.rowType) + val rekeyedRTyp = new OrderedRVDType(right.typ.partitionKey, right.key, right.typ.rowType) + + val newPartitioner = this.rvd.partitioner + val repartitionedRight = + right.rvd.constrainToOrderedPartitioner(right.typ, newPartitioner) + val compute: (OrderedRVIterator, OrderedRVIterator) => Iterator[JoinedRegionValue] = + (joinType: @unchecked) match { + case "inner" => _.innerJoinDistinct(_) + case "left" => _.leftJoinDistinct(_) + } + val joinedRDD = + this.rvd.rdd.zipPartitions(repartitionedRight.rdd, true) { + (leftIt, rightIt) => + joiner(compute( + OrderedRVIterator(rekeyedLTyp, leftIt), + OrderedRVIterator(rekeyedRTyp, rightIt))) + } + + new OrderedRVD(joinedType, newPartitioner, joinedRDD) + } + + def orderedZipJoin(right: KeyedOrderedRVD): RDD[JoinedRegionValue] = { + val newPartitioner = rvd.partitioner.enlargeToRange(right.rvd.partitioner.range) + + val repartitionedLeft = rvd.constrainToOrderedPartitioner(typ, newPartitioner) + val repartitionedRight = right.rvd.constrainToOrderedPartitioner(right.typ, newPartitioner) + + val leftType = this.typ + val rightType = right.typ + repartitionedLeft.rdd.zipPartitions(repartitionedRight.rdd, true){ (leftIt, rightIt) => + OrderedRVIterator(leftType, leftIt).zipJoin(OrderedRVIterator(rightType, rightIt)) + } + } +}
diff --git a/src/main/scala/is/hail/rvd/OrderedRVD.scala b/src/main/scala/is/hail/rvd/OrderedRVD.scala index d03cf04ff01..44422ad928a 100644 --- a/src/main/scala/is/hail/rvd/OrderedRVD.scala +++ b/src/main/scala/is/hail/rvd/OrderedRVD.scala @@ -16,28 +16,13 @@ import org.apache.spark.sql.Row import scala.collection.mutable import scala.reflect.ClassTag -class OrderedRVD private( +class OrderedRVD( val typ: OrderedRVDType, val partitioner: OrderedRVDPartitioner, val rdd: RDD[RegionValue]) extends RVD with Serializable { self => def rowType: TStruct = typ.rowType - // should be totally generic, permitting any number of keys, but that requires more work - def downcastToPK(): OrderedRVD = { - val newType = new OrderedRVDType(partitionKey = typ.partitionKey,
key = typ.partitionKey, - rowType = rowType) - OrderedRVD(newType, partitioner, rdd) - } - - def upcast(castKeys: Array[String]): OrderedRVD = { - val newType = new OrderedRVDType(partitionKey = typ.partitionKey, - key = typ.key ++ castKeys, - rowType = rowType) - OrderedRVD(newType, partitioner, rdd) - } - def mapPreservesPartitioning(newTyp: OrderedRVDType)(f: (RegionValue) => RegionValue): OrderedRVD = OrderedRVD(newTyp, partitioner, @@ -95,42 +80,44 @@ class OrderedRVD private( override def unpersist(): OrderedRVD = this - def orderedJoinDistinct(right: OrderedRVD, joinType: String): RDD[JoinedRegionValue] = { - val lTyp = typ - val rTyp = right.typ + def constrainToOrderedPartitioner( + ordType: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner + ): OrderedRVD = { + + require(ordType.rowType == typ.rowType) + require(ordType.kType isPrefixOf typ.kType) + require(newPartitioner.pkType isIsomorphicTo ordType.pkType) + // Should remove this requirement in the future + require(typ.pkType isPrefixOf ordType.pkType) + + new OrderedRVD( + typ = ordType, + partitioner = newPartitioner, + rdd = new RepartitionedOrderedRDD2(this, newPartitioner)) + } - if (!lTyp.kType.types.sameElements(rTyp.kType.types)) - fatal( - s"""Incompatible join keys. Keys must have same length and types, in order: - | Left key type: ${ lTyp.kType.toString } - | Right key type: ${ rTyp.kType.toString } - """.stripMargin) + def keyBy(key: Array[String] = typ.key): KeyedOrderedRVD = + new KeyedOrderedRVD(this, key) - joinType match { - case "inner" | "left" => new OrderedJoinDistinctRDD2(this, right, joinType) - case _ => fatal(s"Unknown join type `$joinType'. Choose from `inner' or `left'.") - } - } + def orderedJoin( + right: OrderedRVD, + joinType: String, + joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joinedType: OrderedRVDType + ): OrderedRVD = + keyBy().orderedJoin(right.keyBy(), joinType, joiner, joinedType) - def orderedZipJoin(right: OrderedRVD): OrderedZipJoinRDD = { - val pkOrd = this.partitioner.pkType.ordering - if (this.partitioner.range.includes(pkOrd, right.partitioner.range)) - new OrderedZipJoinRDD(this, right) - else { - val newRangeBounds = partitioner.rangeBounds.toArray - val newStart = pkOrd.min(this.partitioner.range.start, right.partitioner.range.start) - val newEnd = pkOrd.max(this.partitioner.range.end, right.partitioner.range.end) - newRangeBounds(0) = newRangeBounds(0).asInstanceOf[Interval] - .copy(start = newStart, includesStart = true) - newRangeBounds(newRangeBounds.length - 1) = newRangeBounds(newRangeBounds.length - 1).asInstanceOf[Interval] - .copy(end = newEnd, includesEnd = true) - - val newPartitioner = new OrderedRVDPartitioner(partitioner.partitionKey, - partitioner.kType, UnsafeIndexedSeq(partitioner.rangeBoundsType, newRangeBounds)) - - new OrderedZipJoinRDD(OrderedRVD(this.typ, newPartitioner, this.rdd), right) - } - } + def orderedJoinDistinct( + right: OrderedRVD, + joinType: String, + joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joinedType: OrderedRVDType + ): OrderedRVD = + keyBy().orderedJoinDistinct(right.keyBy(), joinType, joiner, joinedType) + + def orderedZipJoin(right: OrderedRVD): RDD[JoinedRegionValue] = + keyBy().orderedZipJoin(right.keyBy()) def partitionSortedUnion(rdd2: OrderedRVD): OrderedRVD = { assert(typ == rdd2.typ) @@ -385,8 +372,8 @@ object OrderedRVD { def apply(typ: OrderedRVDType, rdd: RDD[RegionValue], fastKeys: Option[RDD[RegionValue]], hintPartitioner: Option[OrderedRVDPartitioner]): OrderedRVD = { - val 
(_, orderedRDD) = coerce(typ, rdd, fastKeys, hintPartitioner) - orderedRDD + val (_, orderedRVD) = coerce(typ, rdd, fastKeys, hintPartitioner) + orderedRVD } /** @@ -564,7 +551,6 @@ object OrderedRVD { val pkType = partitioner.pkType val pkOrdUnsafe = pkType.unsafeOrdering(true) - val pkOrd = pkType.ordering.toOrdering val pkis = getPartitionKeyInfo(typ, OrderedRVD.getKeys(typ, rdd)) if (pkis.isEmpty) @@ -573,18 +559,7 @@ object OrderedRVD { val min = new UnsafeRow(pkType, pkis.map(_.min).min(pkOrdUnsafe)) val max = new UnsafeRow(pkType, pkis.map(_.max).max(pkOrdUnsafe)) - val newRangeBounds = partitioner.rangeBounds.toArray - - newRangeBounds(0) = newRangeBounds(0).asInstanceOf[Interval] - .copy(start = pkOrd.min(newRangeBounds(0).asInstanceOf[Interval].start, min)) - - newRangeBounds(newRangeBounds.length - 1) = newRangeBounds(newRangeBounds.length - 1).asInstanceOf[Interval] - .copy(end = pkOrd.max(newRangeBounds(newRangeBounds.length - 1).asInstanceOf[Interval].end, max)) - - val newPartitioner = new OrderedRVDPartitioner(partitioner.partitionKey, - partitioner.kType, UnsafeIndexedSeq(partitioner.rangeBoundsType, newRangeBounds)) - - shuffle(typ, newPartitioner, rdd) + shuffle(typ, partitioner.enlargeToRange(Interval(min, max, true, true)), rdd) } def shuffle(typ: OrderedRVDType, diff --git a/src/main/scala/is/hail/rvd/OrderedRVDPartitioner.scala b/src/main/scala/is/hail/rvd/OrderedRVDPartitioner.scala index 7f76f6619d4..ec999923e57 100644 --- a/src/main/scala/is/hail/rvd/OrderedRVDPartitioner.scala +++ b/src/main/scala/is/hail/rvd/OrderedRVDPartitioner.scala @@ -3,6 +3,7 @@ package is.hail.rvd import is.hail.annotations._ import is.hail.expr.types._ import is.hail.utils._ +import org.apache.spark.sql.Row import org.apache.spark.{Partitioner, SparkContext} import org.apache.spark.broadcast.Broadcast @@ -33,8 +34,6 @@ class OrderedRVDPartitioner( (rangeBounds(i).asInstanceOf[Interval], i) }) - val pkKFieldIdx: Array[Int] = partitionKey.map(n => kType.fieldIdx(n)) - def region: Region = rangeBounds.region def loadElement(i: Int): Long = rangeBoundsType.loadElement(region, rangeBounds.aoff, rangeBounds.length, i) @@ -45,17 +44,23 @@ class OrderedRVDPartitioner( def range: Interval = rangeTree.root.get.range - // if outside bounds, return min or max depending on location - // pk: Annotation[pkType] - def getPartitionPK(pk: Any): Int = { - assert(pkType.typeCheck(pk)) - val part = rangeTree.queryValues(pkType.ordering, pk) + /** + * Find the partition containing the given Row. + * + * If pkType is a prefix of the type of row, the prefix of row is used to + * find the partition. + * + * If row falls outside the bounds of the partitioner, return the min or max + * partition. 
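+ * For illustration (hypothetical key types): with pkType {chrom: String,
+ * pos: Int}, the row Row("chr1", 884, "A") is located by its ("chr1", 884)
+ * prefix, relying on the prefix-aware row comparison in
+ * ExtendedOrdering.rowOrdering.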
+ */ + def getPartitionPK(row: Any): Int = { + val part = rangeTree.queryValues(pkType.ordering, row) part match { case Array() => - if (range.isAbovePosition(pkType.ordering, pk)) + if (range.isAbovePosition(pkType.ordering, row)) 0 else { - assert(range.isBelowPosition(pkType.ordering, pk)) + assert(range.isBelowPosition(pkType.ordering, row)) numPartitions - 1 } @@ -63,29 +68,38 @@ class OrderedRVDPartitioner( } } - // return the partition containing key - // if outside bounds, return min or max depending on location - // key: RegionValue[kType] - def getPartition(key: Any): Int = { - val keyrv = key.asInstanceOf[RegionValue] - val wpkrv = WritableRegionValue(pkType) - wpkrv.setSelect(kType, pkKFieldIdx, keyrv) - val pkUR = new UnsafeRow(pkType, wpkrv.value) - - val part = rangeTree.queryValues(pkType.ordering, pkUR) - - part match { - case Array() => - if (range.isAbovePosition(pkType.ordering, pkUR)) - 0 + // Return the sequence of partition IDs overlapping the given interval of + // partition keys. + def getPartitionRange(query: Any): Seq[Int] = { + query match { + case row: Row => + rangeTree.queryValues(pkType.ordering, row) + case interval: Interval => + if (!rangeTree.probablyOverlaps(pkType.ordering, interval)) + Seq.empty[Int] else { - assert(range.isBelowPosition(pkType.ordering, pkUR)) - numPartitions - 1 + val startRange = getPartitionRange(interval.start) + val start = if (startRange.nonEmpty) + startRange.min + else + 0 + val endRange = getPartitionRange(interval.end) + val end = if (endRange.nonEmpty) + endRange.max + else + numPartitions - 1 + start to end } - case Array(x) => x } } + // return the partition containing key + // if outside bounds, return min or max depending on location + // key: RegionValue[kType] + def getPartition(key: Any): Int = + getPartitionPK(new UnsafeRow(kType, key.asInstanceOf[RegionValue])) + + def withKType(newPartitionKey: Array[String], newKType: TStruct): OrderedRVDPartitioner = { val (newPKType, _) = newKType.select(newPartitionKey) val newRangeBounds = new UnsafeIndexedSeq(TArray(TInterval(newPKType)), rangeBounds.region, rangeBounds.aoff) @@ -99,6 +113,18 @@ class OrderedRVDPartitioner( new OrderedRVDPartitioner(partitionKey, kType, rangeBounds) } + // FIXME Make work if newRange has different point type than pkType + def enlargeToRange(newRange: Interval): OrderedRVDPartitioner = { + val newStart = pkType.ordering.min(range.start, newRange.start) + val newEnd = pkType.ordering.max(range.end, newRange.end) + val newRangeBounds = rangeBounds.toArray + newRangeBounds(0) = newRangeBounds(0).asInstanceOf[Interval] + .copy(start = newStart, includesStart = true) + newRangeBounds(newRangeBounds.length - 1) = newRangeBounds(newRangeBounds.length - 1) + .asInstanceOf[Interval].copy(end = newEnd, includesEnd = true) + copy(rangeBounds = UnsafeIndexedSeq(rangeBoundsType, newRangeBounds)) + } + def coalesceRangeBounds(newPartEnd: Array[Int]): OrderedRVDPartitioner = { val newRangeBounds = UnsafeIndexedSeq( rangeBoundsType, diff --git a/src/main/scala/is/hail/rvd/RVD.scala b/src/main/scala/is/hail/rvd/RVD.scala index c78eb0ecafa..1d3e1d5027a 100644 --- a/src/main/scala/is/hail/rvd/RVD.scala +++ b/src/main/scala/is/hail/rvd/RVD.scala @@ -147,6 +147,11 @@ trait RVD { def mapPartitions[T](f: (Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = rdd.mapPartitions(f) + def constrainToOrderedPartitioner( + ordType: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner + ): OrderedRVD + def treeAggregate[U: ClassTag](zeroValue: 
U)( seqOp: (U, RegionValue) => U, combOp: (U, U) => U, diff --git a/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala b/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala index cc378cfc3bd..d581cd3daa1 100644 --- a/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala +++ b/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala @@ -1,9 +1,9 @@ package is.hail.rvd -import is.hail.annotations.RegionValue -import is.hail.utils._ +import is.hail.annotations.{KeyedRow, RegionValue, UnsafeRow} import is.hail.expr.types.TStruct import is.hail.io.CodecSpec +import is.hail.utils._ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -49,4 +49,26 @@ class UnpartitionedRVD(val rowType: TStruct, val rdd: RDD[RegionValue]) extends } def coalesce(maxPartitions: Int, shuffle: Boolean): UnpartitionedRVD = new UnpartitionedRVD(rowType, rdd.coalesce(maxPartitions, shuffle = shuffle)) + + def constrainToOrderedPartitioner( + ordType: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner + ): OrderedRVD = { + + assert(ordType.rowType == rowType) + + val localRowType = rowType + val pkOrdering = ordType.pkType.ordering + val rangeTree = newPartitioner.rangeTree + val filtered = rdd.mapPartitions { it => + val ur = new UnsafeRow(localRowType, null, 0) + val key = new KeyedRow(ur, ordType.pkRowFieldIdx) + it.filter { rv => + ur.set(rv) + rangeTree.contains(pkOrdering, key) + } + } + + OrderedRVD.shuffle(ordType, newPartitioner, filtered) + } } diff --git a/src/main/scala/is/hail/sparkextras/OrderedRDD2.scala b/src/main/scala/is/hail/sparkextras/OrderedRDD2.scala deleted file mode 100644 index 04bcb2255a6..00000000000 --- a/src/main/scala/is/hail/sparkextras/OrderedRDD2.scala +++ /dev/null @@ -1,111 +0,0 @@ -package is.hail.sparkextras - -import is.hail.annotations._ -import is.hail.rvd.{OrderedRVD, OrderedRVDPartitioner, OrderedRVDType} -import is.hail.utils._ -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row - -class OrderedDependency(left: OrderedRVD, right: OrderedRVD) extends NarrowDependency[RegionValue](right.rdd) { - override def getParents(partitionId: Int): Seq[Int] = - OrderedDependency.getDependencies(left.partitioner, right.partitioner)(partitionId) -} - -object OrderedDependency { - def getDependencies(p1: OrderedRVDPartitioner, p2: OrderedRVDPartitioner)(partitionId: Int): Seq[Int] = { - val partBounds = p1.rangeBounds(partitionId).asInstanceOf[Interval] - - if (!p2.rangeTree.probablyOverlaps(p2.pkType.ordering, partBounds)) - Seq.empty[Int] - else { - val start = p2.getPartitionPK(partBounds.start) - val end = p2.getPartitionPK(partBounds.end) - start to end - } - } -} - -case class OrderedJoinDistinctRDD2Partition(index: Int, leftPartition: Partition, rightPartitions: Array[Partition]) extends Partition - -class OrderedJoinDistinctRDD2(left: OrderedRVD, right: OrderedRVD, joinType: String) - extends RDD[JoinedRegionValue](left.sparkContext, - Seq[Dependency[_]](new OneToOneDependency(left.rdd), - new OrderedDependency(left, right))) { - assert(joinType == "left" || joinType == "inner") - override val partitioner: Option[Partitioner] = Some(left.partitioner) - - def getPartitions: Array[Partition] = { - Array.tabulate[Partition](left.getNumPartitions)(i => - OrderedJoinDistinctRDD2Partition(i, - left.partitions(i), - OrderedDependency.getDependencies(left.partitioner, right.partitioner)(i) - .map(right.partitions) - .toArray)) - } - - override def getPreferredLocations(split: Partition): Seq[String] 
= left.rdd.preferredLocations(split) - - override def compute(split: Partition, context: TaskContext): Iterator[JoinedRegionValue] = { - val partition = split.asInstanceOf[OrderedJoinDistinctRDD2Partition] - - val leftIt = left.rdd.iterator(partition.leftPartition, context) - val rightIt = partition.rightPartitions.iterator.flatMap { p => - right.rdd.iterator(p, context) - } - - joinType match { - case "inner" => OrderedRVIterator(left.typ, leftIt) - .innerJoinDistinct(OrderedRVIterator(right.typ, rightIt)) - case "left" => OrderedRVIterator(left.typ, leftIt) - .leftJoinDistinct(OrderedRVIterator(right.typ, rightIt)) - case _ => fatal(s"Unknown join type `$joinType'. Choose from `inner' or `left'.") - } - } -} - -case class OrderedZipJoinRDDPartition( - index: Int, - leftPartition: Partition, - rightPartitions: Array[Partition]) - extends Partition - -class OrderedZipJoinRDD(left: OrderedRVD, right: OrderedRVD) - extends RDD[JoinedRegionValue](left.sparkContext, - Seq[Dependency[_]](new OneToOneDependency(left.rdd), - new OrderedDependency(left, right))) { - - assert(left.partitioner.range.includes(left.partitioner.pkType.ordering, right.partitioner.range)) - - private val leftPartitionForRightRow = new OrderedRVDPartitioner( - right.typ.partitionKey, - right.typ.rowType, - left.partitioner.rangeBounds) - - override val partitioner: Option[Partitioner] = Some(left.partitioner) - - def getPartitions: Array[Partition] = { - Array.tabulate[Partition](left.getNumPartitions)(i => - OrderedZipJoinRDDPartition(i, - left.partitions(i), - OrderedDependency.getDependencies(left.partitioner, right.partitioner)(i) - .map(right.partitions) - .toArray)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = - left.rdd.preferredLocations(split) - - override def compute(split: Partition, context: TaskContext): Iterator[JoinedRegionValue] = { - val partition = split.asInstanceOf[OrderedZipJoinRDDPartition] - val index = partition.index - - val leftIt = left.rdd.iterator(partition.leftPartition, context) - val rightIt = partition.rightPartitions.iterator.flatMap { p => - right.rdd.iterator(p, context) - } - .filter { rrv => leftPartitionForRightRow.getPartition(rrv) == index } - - OrderedRVIterator(left.typ, leftIt).zipJoin(OrderedRVIterator(right.typ, rightIt)) - } -} diff --git a/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala b/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala new file mode 100644 index 00000000000..8b34b2f7687 --- /dev/null +++ b/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala @@ -0,0 +1,69 @@ +package is.hail.sparkextras + +import is.hail.annotations._ +import is.hail.rvd.{OrderedRVD, OrderedRVDPartitioner, OrderedRVDType} +import is.hail.utils._ +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row + +/** + * Repartition prev to comply with newPartitioner, using narrow dependencies. + * Assumes new key type is a prefix of old key type, so no reordering is + * needed. No assumption should need to be made about partition keys, but currently + * assumes old partition key type is a prefix of the new partition key type. + */ +class RepartitionedOrderedRDD2( + prev: OrderedRVD, + newPartitioner: OrderedRVDPartitioner) + extends RDD[RegionValue](prev.sparkContext, Nil) { // Nil since we implement getDependencies + +// require(newPartitioner.kType isPrefixOf prev.typ.kType) + // There should really be no precondition on partition keys. 
Drop this when + // we're able + require(prev.typ.pkType isPrefixOf newPartitioner.pkType) + + def getPartitions: Array[Partition] = { + Array.tabulate[Partition](newPartitioner.numPartitions) { i => + RepartitionedOrderedRDD2Partition( + i, + dependency.getParents(i).toArray.map(prev.rdd.partitions), + newPartitioner.rangeBounds(i).asInstanceOf[Interval]) + } + } + + override def compute(partition: Partition, context: TaskContext): Iterator[RegionValue] = { + val ordPartition = partition.asInstanceOf[RepartitionedOrderedRDD2Partition] + val it = ordPartition.parents.iterator + .flatMap { parentPartition => + prev.rdd.iterator(parentPartition, context) + } + OrderedRVIterator(prev.typ, it).restrictToPKInterval(ordPartition.range) + } + + val dependency = new OrderedDependency(prev, newPartitioner) + + override def getDependencies: Seq[Dependency[_]] = Seq(dependency) +} + +class OrderedDependency(prev: OrderedRVD, newPartitioner: OrderedRVDPartitioner) + extends NarrowDependency[RegionValue](prev.rdd) { + + // no precondition on partition keys +// require(newPartitioner.kType isPrefixOf prev.typ.kType) + // There should really be no precondition on partition keys. Drop this when + // we're able + require(prev.typ.pkType isPrefixOf newPartitioner.pkType) + + override def getParents(partitionId: Int): Seq[Int] = { + val partBounds = + newPartitioner.rangeBounds(partitionId).asInstanceOf[Interval] + prev.partitioner.getPartitionRange(partBounds) + } +} + +case class RepartitionedOrderedRDD2Partition( + index: Int, + parents: Array[Partition], + range: Interval) + extends Partition diff --git a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala index 244b44ec478..18c0a1cdfd0 100644 --- a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala +++ b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala @@ -7,15 +7,22 @@ import is.hail.annotations._ import is.hail.expr.types._ import is.hail.rvd.OrderedRVD import is.hail.utils._ -import is.hail.variant.{Call, Call2, ReferenceGenome, MatrixTable} +import is.hail.variant.{Call2, ReferenceGenome, MatrixTable} import org.apache.commons.math3.random.JDKRandomGenerator object BaldingNicholsModel { - def apply(hc: HailContext, nPops: Int, nSamples: Int, nVariants: Int, - popDistArrayOpt: Option[Array[Double]], FstOfPopArrayOpt: Option[Array[Double]], - seed: Int, nPartitionsOpt: Option[Int], af_dist: Distribution, - rg: ReferenceGenome = ReferenceGenome.defaultReference): MatrixTable = { + def apply(hc: HailContext, + nPops: Int, + nSamples: Int, + nVariants: Int, + popDistArrayOpt: Option[Array[Double]], + FstOfPopArrayOpt: Option[Array[Double]], + seed: Int, + nPartitionsOpt: Option[Int], + af_dist: Distribution, + rg: ReferenceGenome = ReferenceGenome.defaultReference, + mixture: Boolean = false): MatrixTable = { val sc = hc.sc @@ -69,17 +76,24 @@ object BaldingNicholsModel { Rand.generator.setSeed(seed) val popDist_k = popDist - popDist_k :/= sum(popDist_k) - - val popDistRV = Multinomial(popDist_k) - val popOfSample_n: DenseVector[Int] = DenseVector.fill[Int](N)(popDistRV.draw()) + val popOfSample_n = DenseMatrix.zeros[Double](if (mixture) K else 1, N) + + if (mixture) { + val popDistRV = Dirichlet(popDist_k) + (0 until N).foreach(j => popOfSample_n(::, j) := popDistRV.draw()) + } else { + popDist_k :/= sum(popDist_k) + val popDistRV = Multinomial(popDist_k) + (0 until N).foreach(j => popOfSample_n(0, j) = popDistRV.draw()) + } + val popOfSample_nBc = sc.broadcast(popOfSample_n) val 
Fst_k = FstOfPop val Fst1_k = (1d - Fst_k) /:/ Fst_k val Fst1_kBc = sc.broadcast(Fst1_k) - val saSignature = TStruct("sample_idx" -> TInt32(), "pop" -> TInt32()) + val saSignature = TStruct("sample_idx" -> TInt32(), "pop" -> (if (mixture) TArray(TFloat64()) else TInt32())) val vaSignature = TStruct("ancestralAF" -> TFloat64(), "AF" -> TArray(TFloat64())) val ancestralAFAnnotation = af_dist match { @@ -88,7 +102,7 @@ object BaldingNicholsModel { case TruncatedBetaDist(a, b, min, max) => Annotation("TruncatedBetaDist", a, b, min, max) } val globalAnnotation = - Annotation(K, N, M, popDistArray: IndexedSeq[Double], FstOfPopArray: IndexedSeq[Double], ancestralAFAnnotation, seed) + Annotation(K, N, M, popDistArray: IndexedSeq[Double], FstOfPopArray: IndexedSeq[Double], ancestralAFAnnotation, seed, mixture) val ancestralAFAnnotationSignature = af_dist match { case UniformDist(min, max) => TStruct("type" -> TString(), "min" -> TFloat64(), "max" -> TFloat64()) @@ -103,7 +117,8 @@ object BaldingNicholsModel { "pop_dist" -> TArray(TFloat64()), "fst" -> TArray(TFloat64()), "ancestral_af_dist" -> ancestralAFAnnotationSignature, - "seed" -> TInt32()) + "seed" -> TInt32(), + "mixture" -> TBoolean()) val matrixType: MatrixType = MatrixType.fromParts( globalType = globalSignature, @@ -130,9 +145,11 @@ object BaldingNicholsModel { val ancestralAF = af_dist.getBreezeDist(perVariantRandomBasis).draw() - val popAF_k: IndexedSeq[Double] = Array.tabulate(K) { k => - new Beta(ancestralAF * Fst1_kBc.value(k), (1 - ancestralAF) * Fst1_kBc.value(k))(perVariantRandomBasis).draw() - } + val popAF_k: DenseVector[Double] = DenseVector( + Array.tabulate(K) { k => + new Beta(ancestralAF * Fst1_kBc.value(k), (1 - ancestralAF) * Fst1_kBc.value(k))(perVariantRandomBasis) + .draw() + }) region.clear() rvb.start(rvType) @@ -165,7 +182,11 @@ object BaldingNicholsModel { i = 0 val unif = new Uniform(0, 1)(perVariantRandomBasis) while (i < N) { - val p = popAF_k(popOfSample_nBc.value(i)) + val p = + if (mixture) + popOfSample_nBc.value(::, i) dot popAF_k + else + popAF_k(popOfSample_nBc.value(0, i).toInt) val pSq = p * p val x = unif.draw() val c = @@ -188,7 +209,11 @@ object BaldingNicholsModel { } } - val sampleAnnotations = (0 until N).map { i => Annotation(i, popOfSample_n(i)) }.toArray + val sampleAnnotations: Array[Annotation] = + if (mixture) + Array.tabulate(N)(i => Annotation(i, popOfSample_n(::, i).data.toIndexedSeq)) + else + Array.tabulate(N)(i => Annotation(i, popOfSample_n(0, i).toInt)) // FIXME: should use fast keys val ordrdd = OrderedRVD(matrixType.orvdType, rdd, None, None) diff --git a/src/main/scala/is/hail/table/Table.scala b/src/main/scala/is/hail/table/Table.scala index 0901f27b4c9..160ab2bbfa5 100644 --- a/src/main/scala/is/hail/table/Table.scala +++ b/src/main/scala/is/hail/table/Table.scala @@ -354,70 +354,24 @@ class Table(val hc: HailContext, val tir: TableIR) { annotateGlobal(ann, t, name) } - def annotateGlobalExpr(expr: String): Table = { + def selectGlobal(expr: String): Table = { val ec = EvalContext("global" -> globalSignature) ec.set(0, globals.value) - val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, None) - - val inserterBuilder = new ArrayBuilder[Inserter]() - - val finalType = (paths, types).zipped.foldLeft(globalSignature) { case (v, (ids, signature)) => - val (s, i) = v.insert(signature, ids) - inserterBuilder += i - s.asInstanceOf[TStruct] - } - - val inserters = inserterBuilder.result() - - val ga = inserters - .zip(f()) - .foldLeft(globals.value) { case (a, (ins, res)) => 
- ins(a, res).asInstanceOf[Row] - } - - copy2(globals = globals.copy(value = ga, t = finalType), - globalSignature = finalType) - } - - def selectGlobal(fields: java.util.ArrayList[String]): Table = { - selectGlobal(fields.asScala.toArray: _*) - } - - def selectGlobal(fields: String*): Table = { - val ec = EvalContext("global" -> globalSignature) - ec.set(0, globals.value) - - val (paths, types, f) = Parser.parseSelectExprs(fields.toArray, ec) - - val names = paths.map { - case Left(n) => n - case Right(l) => l.last - } - - val overlappingPaths = names.counter().filter { case (n, i) => i != 1 }.keys - - if (overlappingPaths.nonEmpty) - fatal(s"Found ${ overlappingPaths.size } ${ plural(overlappingPaths.size, "selected field name") } that are duplicated.\n" + - "Overlapping fields:\n " + - s"@1", overlappingPaths.truncatable("\n ")) + val ast = Parser.parseToAST(expr, ec) + assert(ast.`type`.isInstanceOf[TStruct]) - val inserterBuilder = new ArrayBuilder[Inserter]() + ast.toIR() match { + case Some(ir) => + new Table(hc, TableMapGlobals(tir, ir)) + case None => + val (t, f) = Parser.parseExpr(expr, ec) + val newSignature = t.asInstanceOf[TStruct] + val newGlobal = f() - val finalSignature = (names, types).zipped.foldLeft(TStruct()) { case (vs, (p, sig)) => - val (s: TStruct, i) = vs.insert(sig, p) - inserterBuilder += i - s + copy2(globalSignature = newSignature, + globals = globals.copy(value = newGlobal, t = newSignature)) } - - val inserters = inserterBuilder.result() - - val newGlobal = f().zip(inserters) - .foldLeft(Row()) { case (a1, (v, inserter)) => - inserter(a1, v).asInstanceOf[Row] - } - - copy2(globalSignature = finalSignature, globals = globals.copy(value = newGlobal, t = finalSignature)) } def filter(cond: String, keep: Boolean): Table = { @@ -496,72 +450,8 @@ class Table(val hc: HailContext, val tir: TableIR) { } } - def join(other: Table, joinType: String): Table = { - if (key.length != other.key.length || !(keyFields.map(_.typ) sameElements other.keyFields.map(_.typ))) - fatal( - s"""Both tables must have the same number of keys and the types of keys must be identical. Order matters. 
- | Left signature: ${ keySignature.toString } - | Right signature: ${ other.keySignature.toString }""".stripMargin) - - val joinedFields = keySignature.fields ++ valueSignature.fields ++ other.valueSignature.fields - - val preNames = joinedFields.map(_.name).toArray - val (finalColumnNames, remapped) = mangle(preNames) - if (remapped.nonEmpty) { - warn(s"Remapped ${ remapped.length } ${ plural(remapped.length, "column") } from right-hand table:\n @1", - remapped.map { case (pre, post) => s""""$pre" => "$post"""" }.truncatable("\n ")) - } - - val newSignature = TStruct(joinedFields - .zipWithIndex - .map { case (fd, i) => (finalColumnNames(i), fd.typ) }: _*) - val localNKeys = nKeys - val size1 = valueSignature.size - val size2 = other.valueSignature.size - val totalSize = newSignature.size - - assert(totalSize == localNKeys + size1 + size2) - - val merger = (k: Row, r1: Row, r2: Row) => { - val result = Array.fill[Any](totalSize)(null) - - var i = 0 - while (i < localNKeys) { - result(i) = k.get(i) - i += 1 - } - - if (r1 != null) { - i = 0 - while (i < size1) { - result(localNKeys + i) = r1.get(i) - i += 1 - } - } - - if (r2 != null) { - i = 0 - while (i < size2) { - result(localNKeys + size1 + i) = r2.get(i) - i += 1 - } - } - Row.fromSeq(result) - } - - val rddLeft = keyedRDD() - val rddRight = other.keyedRDD() - - val joinedRDD = joinType match { - case "left" => rddLeft.leftOuterJoin(rddRight).map { case (k, (l, r)) => merger(k, l, r.orNull) } - case "right" => rddLeft.rightOuterJoin(rddRight).map { case (k, (l, r)) => merger(k, l.orNull, r) } - case "inner" => rddLeft.join(rddRight).map { case (k, (l, r)) => merger(k, l, r) } - case "outer" => rddLeft.fullOuterJoin(rddRight).map { case (k, (l, r)) => merger(k, l.orNull, r.orNull) } - case _ => fatal("Invalid join type specified. 
Choose one of `left', `right', `inner', `outer'") - } - - copy(rdd = joinedRDD, signature = newSignature, key = key) - } + def join(other: Table, joinType: String): Table = + new Table(hc, TableJoin(this.tir, other.tir, joinType)) def export(output: String, typesFile: String = null, header: Boolean = true, exportType: Int = ExportType.CONCATENATED) { val hConf = hc.hadoopConf @@ -1085,7 +975,7 @@ class Table(val hc: HailContext, val tir: TableIR) { nodes.annotateGlobal(relatedNodesToKeep, TSet(iType), "relatedNodesToKeep") .filter(s"global.relatedNodesToKeep.contains(row.node)", keep = keep) - .selectGlobal() + .selectGlobal("{}") } def show(n: Int = 10, truncate: Option[Int] = None, printTypes: Boolean = true, maxWidth: Int = 100): Unit = { @@ -1252,8 +1142,6 @@ class Table(val hc: HailContext, val tir: TableIR) { } def toOrderedRVD(hintPartitioner: Option[OrderedRVDPartitioner], partitionKeys: Int): OrderedRVD = { - val localSignature = signature - val orderedKTType = new OrderedRVDType(key.take(partitionKeys).toArray, key.toArray, signature) assert(hintPartitioner.forall(p => p.pkType.types.sameElements(orderedKTType.pkType.types))) OrderedRVD(orderedKTType, rvd.rdd, None, hintPartitioner) diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index 56d894198a2..a2bd2ca651d 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -4,6 +4,7 @@ import is.hail.annotations._ import is.hail.check.Gen import is.hail.linalg._ import is.hail.expr._ +import is.hail.expr.ir import is.hail.methods._ import is.hail.rvd._ import is.hail.table.{Table, TableSpec} @@ -976,75 +977,70 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { def orderedRVDLeftJoinDistinctAndInsert(right: OrderedRVD, root: String, product: Boolean): MatrixTable = { assert(!rowKey.contains(root)) - assert(right.typ.pkType.types.map(_.deepOptional()) - .sameElements(rowPartitionKeyTypes.map(_.deepOptional()))) + val valueType = if (product) + TArray(right.typ.valueType, required = true) + else + right.typ.valueType - val (leftRVD, upcastKeys) = if (right.typ.kType.types.map(_.deepOptional()).sameElements(rowPartitionKeyTypes.map(_.deepOptional()))) { - (rvd.downcastToPK(), rowKey.drop(rowPartitionKey.length).toArray) - } else (rvd, Array.empty[String]) - - var valueType: Type = right.typ.valueType - - var rightRVD = right - if (product) { - valueType = TArray(valueType, required = true) - rightRVD = rightRVD.groupByKey(" !!! values !!! ") - } + val rightRVD = if (product) + right.groupByKey(" !!! values !!! 
") + else + right val (newRVType, ins) = rvRowType.unsafeStructInsert(valueType, List(root)) - val leftRowType = leftRVD.rowType val rightRowType = rightRVD.rowType - val oldRVType = leftRVD.typ.rowType val rightValueIndices = rightRVD.typ.valueIndices assert(!product || rightValueIndices.length == 1) val newMatrixType = matrixType.copy(rvRowType = newRVType) - val intermediateMatrixType = newMatrixType.copy(rowKey = newMatrixType.rowPartitionKey) - copyMT(matrixType = newMatrixType, - rvd = OrderedRVD( - newMatrixType.orvdType, - leftRVD.partitioner, - leftRVD.orderedJoinDistinct(rightRVD, "left") - .mapPartitions { it => - val rvb = new RegionValueBuilder() - val rv = RegionValue() - - it.map { jrv => - val lrv = jrv.rvLeft - - rvb.set(lrv.region) - rvb.start(newRVType) - ins(lrv.region, lrv.offset, rvb, - () => { - if (product) { - if (jrv.rvRight == null) { - rvb.startArray(0) - rvb.endArray() - } else - rvb.addField(rightRowType, jrv.rvRight, rightValueIndices(0)) - } else { - if (jrv.rvRight == null) - rvb.setMissing() - else { - rvb.startStruct() - var i = 0 - while (i < rightValueIndices.length) { - rvb.addField(rightRowType, jrv.rvRight, rightValueIndices(i)) - i += 1 - } - rvb.endStruct() - } - } - }) - rv.set(lrv.region, rvb.end()) - rv + val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it => + val rvb = new RegionValueBuilder() + val rv = RegionValue() + + it.map { jrv => + val lrv = jrv.rvLeft + + rvb.set(lrv.region) + rvb.start(newRVType) + ins(lrv.region, lrv.offset, rvb, + () => { + if (product) { + if (jrv.rvRight == null) { + rvb.startArray(0) + rvb.endArray() + } else + rvb.addField(rightRowType, jrv.rvRight, rightValueIndices(0)) + } else { + if (jrv.rvRight == null) + rvb.setMissing() + else { + rvb.startStruct() + var i = 0 + while (i < rightValueIndices.length) { + rvb.addField(rightRowType, jrv.rvRight, rightValueIndices(i)) + i += 1 + } + rvb.endStruct() + } } }) + rv.set(lrv.region, rvb.end()) + rv + } + } + + val joinedRVD = this.rvd.keyBy(rowKey.take(right.typ.key.length).toArray).orderedJoinDistinct( + right.keyBy(), + "left", + joiner, + newMatrixType.orvdType ) + + copyMT(matrixType = newMatrixType, rvd = joinedRVD) } private def annotateRowsIntervalTable(kt: Table, root: String, product: Boolean): MatrixTable = { @@ -1336,106 +1332,46 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { rvd = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType)(mapPartitionsF)) } - def selectEntries(selectExprs: java.util.ArrayList[String]): MatrixTable = selectEntries(selectExprs.asScala.toArray: _*) - - def selectEntries(exprs: String*): MatrixTable = { + def selectEntries(expr: String): MatrixTable = { val ec = entryEC - val globalsBc = globals.broadcast - - val (paths, types, f) = Parser.parseSelectExprs(exprs.toArray, ec) - val topLevelFields = mutable.Set.empty[String] - - val finalNames = paths.map { - // assignment - case Left(name) => name - case Right(path) => - assert(path.head == Annotation.ENTRY_HEAD) - path match { - case List(Annotation.ENTRY_HEAD, name) => topLevelFields += name - } - path.last - } - assert(finalNames.areDistinct()) - - val newEntryType = TStruct(finalNames.zip(types): _*) - val fullRowType = rvRowType - val localEntriesIndex = entriesIndex - val localNCols = numCols - val localColValuesBc = colValuesBc - - insertEntries(() => { - val fullRow = new UnsafeRow(fullRowType) - val row = fullRow.deleteField(localEntriesIndex) - ec.set(0, globalsBc.value) - ec.set(1, row) - fullRow -> row - })(newEntryType, { 
case ((fullRow, row), rv, rvb) => - fullRow.set(rv) - val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) - rvb.startArray(localNCols) - var i = 0 - while (i < localNCols) { - val entry = entries(i) - ec.set(2, localColValuesBc.value(i)) - ec.set(3, entry) - val results = f() - var j = 0 - rvb.startStruct() - while (j < types.length) { - rvb.addAnnotation(types(j), results(j)) - j += 1 - } - rvb.endStruct() - i += 1 - } - rvb.endArray() - }) - } - - - def dropEntries(fields: java.util.ArrayList[String]): MatrixTable = dropEntries(fields.asScala.toArray: _*) - - def dropEntries(fields: String*): MatrixTable = { - if (fields.isEmpty) - return this - assert(fields.areDistinct()) - val dropSet = fields.toSet - val allEntryFields = entryType.fieldNames.toSet - assert(fields.forall(allEntryFields.contains)) - - val keepIndices = entryType.fields - .filter(f => !dropSet.contains(f.name)) - .map(f => f.index) - .toArray - val newEntryType = TStruct(keepIndices.map(entryType.fields(_)).map(f => f.name -> f.typ): _*) - val fullRowType = rvRowType - val localEntriesIndex = entriesIndex - val localEntriesType = matrixType.entryArrayType - val localEntryType = entryType - val localNCols = numCols - // FIXME: replace with physical type - insertEntries(noOp)(newEntryType, { case (_, rv, rvb) => - val entriesOffset = fullRowType.loadField(rv, localEntriesIndex) - rvb.startArray(localNCols) - var i = 0 - while (i < localNCols) { - if (localEntriesType.isElementMissing(rv.region, entriesOffset, i)) - rvb.setMissing() - else { - val eltOffset = localEntriesType.loadElement(rv.region, entriesOffset, localNCols, i) - rvb.startStruct() - var j = 0 - while (j < keepIndices.length) { - rvb.addField(localEntryType, rv.region, eltOffset, keepIndices(j)) - j += 1 + val entryAST = Parser.parseToAST(expr, ec) + assert(entryAST.`type`.isInstanceOf[TStruct]) + + entryAST.toIR() match { + case Some(ir) => + new MatrixTable(hc, MapEntries(ast, ir)) + case None => + val (t, f) = Parser.parseExpr(expr, ec) + val newEntryType = t.asInstanceOf[TStruct] + val globalsBc = globals.broadcast + val fullRowType = rvRowType + val localEntriesIndex = entriesIndex + val localNCols = numCols + val localColValuesBc = colValuesBc + + insertEntries(() => { + val fullRow = new UnsafeRow(fullRowType) + val row = fullRow.deleteField(localEntriesIndex) + ec.set(0, globalsBc.value) + fullRow -> row + })(newEntryType, { case ((fullRow, row), rv, rvb) => + fullRow.set(rv) + ec.set(1, row) + val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) + rvb.startArray(localNCols) + var i = 0 + while (i < localNCols) { + val entry = entries(i) + ec.set(2, localColValuesBc.value(i)) + ec.set(3, entry) + val result = f() + rvb.addAnnotation(newEntryType, result) + i += 1 } - rvb.endStruct() - } - i += 1 - } - rvb.endArray() - }) + rvb.endArray() + }) + } } def nPartitions: Int = rvd.partitions.length @@ -1568,61 +1504,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { new Table(hc, TableLiteral(TableValue(TableType(rvRowType.rename(m), rowKey, globalType), globals, rvd))) } - def annotateEntriesExpr(expr: String): MatrixTable = { - val symTab = Map( - "va" -> (0, rowType), - "sa" -> (1, colType), - "g" -> (2, entryType), - "global" -> (3, globalType)) - val ec = EvalContext(symTab) - - val globalsBc = globals.broadcast - - val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD)) - - val inserterBuilder = new ArrayBuilder[Inserter]() - val newEntryType = (paths, 
types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) => - val (s, i) = gsig.structInsert(signature, ids) - inserterBuilder += i - s - } - val inserters = inserterBuilder.result() - - val localNSamples = numCols - val fullRowType = rvRowType - val localColValuesBc = colValuesBc - val localEntriesIndex = entriesIndex - - insertEntries(() => { - val fullRow = new UnsafeRow(fullRowType) - val row = fullRow.deleteField(localEntriesIndex) - (fullRow, row) - })(newEntryType, { case ((fullRow, row), rv, rvb) => - fullRow.set(rv) - val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) - - rvb.startArray(localNSamples) - - var i = 0 - while (i < localNSamples) { - val entry = entries(i) - ec.setAll(row, - localColValuesBc.value(i), - entry, - globalsBc.value) - - val newEntry = f().zip(inserters) - .foldLeft(entry) { case (ga, (a, inserter)) => - inserter(ga, a) - } - rvb.addAnnotation(newEntryType, newEntry) - - i += 1 - } - rvb.endArray() - }) - } - def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = { copyAST(ast = MatrixLiteral(matrixType, value.filterCols(p))) } @@ -1761,7 +1642,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val localEntriesType = matrixType.entryArrayType assert(right.matrixType.entryArrayType == localEntriesType) - val joined = rvd.orderedJoinDistinct(right.rvd, "inner").mapPartitions({ it => + val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it => val rvb = new RegionValueBuilder() val rv2 = RegionValue() @@ -1805,13 +1686,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { rv2.set(lrv.region, rvb.end()) rv2 } - }, preservesPartitioning = true) + } val newMatrixType = matrixType.copyParts() // move entries to the end copyMT(matrixType = newMatrixType, colValues = colValues ++ right.colValues, - rvd = OrderedRVD(rvd.typ, rvd.partitioner, joined)) + rvd = rvd.orderedJoinDistinct(right.rvd, "inner", joiner, rvd.typ)) } def makeKT(rowExpr: String, entryExpr: String, keyNames: Array[String] = Array.empty, seperator: String = "."): Table = { diff --git a/src/test/scala/is/hail/expr/TableIRSuite.scala b/src/test/scala/is/hail/expr/TableIRSuite.scala index e1d9bf6202c..ce8178fc7ac 100644 --- a/src/test/scala/is/hail/expr/TableIRSuite.scala +++ b/src/test/scala/is/hail/expr/TableIRSuite.scala @@ -26,7 +26,7 @@ class TableIRSuite extends SparkSuite { } @Test def testFilterGlobals() { - val kt = getKT.annotateGlobalExpr("g = 3") + val kt = getKT.selectGlobal("{g: 3}") val kt2 = new Table(hc, TableFilter(kt.tir, ir.ApplyBinaryPrimOp(ir.EQ(), ir.GetField(ir.Ref("row"), "field1"), ir.GetField(ir.Ref("global"), "g")))) assert(kt2.count() == 1) diff --git a/src/test/scala/is/hail/expr/ir/CompileSuite.scala b/src/test/scala/is/hail/expr/ir/CompileSuite.scala index 8a76c7130d2..79ad29f764c 100644 --- a/src/test/scala/is/hail/expr/ir/CompileSuite.scala +++ b/src/test/scala/is/hail/expr/ir/CompileSuite.scala @@ -494,7 +494,7 @@ class CompileSuite { val a2t = TArray(TString()) val a1 = In(0, TArray(TInt32())) val a2 = In(1, TArray(TString())) - val min = IRFunctionRegistry.lookupFunction("min", Seq(TArray(TInt32()))).get + val min = IRFunctionRegistry.lookupConversion("min", Seq(TArray(TInt32()))).get val range = ArrayRange(I32(0), min(Seq(MakeArray(Seq(ArrayLen(a1), ArrayLen(a2))))), I32(1)) val ir = ArrayMap(range, "i", MakeTuple(Seq(ArrayRef(a1, Ref("i")), ArrayRef(a2, Ref("i"))))) val region = Region() diff --git a/src/test/scala/is/hail/expr/ir/FunctionSuite.scala 
b/src/test/scala/is/hail/expr/ir/FunctionSuite.scala index 5e6302ae411..f597ee23f9f 100644 --- a/src/test/scala/is/hail/expr/ir/FunctionSuite.scala +++ b/src/test/scala/is/hail/expr/ir/FunctionSuite.scala @@ -58,10 +58,8 @@ class FunctionSuite { fb.result(Some(new PrintWriter(System.out)))() } - def lookup(meth: String, types: Type*)(irs: IR*): IR = { - val possible = IRFunctionRegistry.registry(meth) - IRFunctionRegistry.lookupFunction(meth, types).get(irs) - } + def lookup(meth: String, types: Type*)(irs: IR*): IR = + IRFunctionRegistry.lookupConversion(meth, types).get(irs) @Test def testCodeFunction() { @@ -110,10 +108,10 @@ class FunctionSuite { @Test def testVariableUnification() { - assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt32(), TInt32())).isDefined) - assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty) - assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty) - assert(IRFunctionRegistry.lookupFunction("testCodeUnification2", Seq(TArray(TInt32()))).isDefined) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt32(), TInt32())).isDefined) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification2", Seq(TArray(TInt32()))).isDefined) } @Test diff --git a/src/test/scala/is/hail/io/ExportVCFSuite.scala b/src/test/scala/is/hail/io/ExportVCFSuite.scala index b1da4c2235d..b906afff185 100644 --- a/src/test/scala/is/hail/io/ExportVCFSuite.scala +++ b/src/test/scala/is/hail/io/ExportVCFSuite.scala @@ -9,6 +9,7 @@ import is.hail.io.vcf.ExportVCF import is.hail.utils._ import is.hail.variant.{MatrixTable, VSMSubgen, Variant} import org.testng.annotations.Test +import is.hail.testUtils._ import scala.io.Source import scala.language.postfixOps @@ -210,13 +211,13 @@ class ExportVCFSuite extends SparkSuite { TestUtils.interceptFatal("Invalid type for format field 'BOOL'. 
Found 'bool'.") { ExportVCF(vds - .annotateEntriesExpr("g = {BOOL: true}"), + .annotateEntriesExpr(("BOOL","true")), out) } TestUtils.interceptFatal("Invalid type for format field 'AA'.") { ExportVCF(vds - .annotateEntriesExpr("g = {AA: [[0]]}"), + .annotateEntriesExpr(("AA", "[[0]]")), out) } } @@ -281,15 +282,15 @@ class ExportVCFSuite extends SparkSuite { val callArrayFields = schema.fields.filter(fd => fd.typ == TArray(TCall())).map(_.name) val callSetFields = schema.fields.filter(fd => fd.typ == TSet(TCall())).map(_.name) - val callAnnots = callFields.map(name => s"g.$name = let c = g.$name in " + - s"if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c") + val callAnnots = callFields.map(name => (name, s"let c = g.$name in " + + s"if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c")) - val callContainerAnnots = (callArrayFields ++ callSetFields).map(name => s"g.$name = " + - s"g.$name.map(c => if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c)") + val callContainerAnnots = (callArrayFields ++ callSetFields).map(name => (name, + s"g.$name.map(c => if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c)")) val annots = callAnnots ++ callContainerAnnots - val vsmAnn = if (annots.nonEmpty) vsm.annotateEntriesExpr(annots.mkString(",")) else vsm + val vsmAnn = if (annots.nonEmpty) vsm.annotateEntriesExpr(annots: _*) else vsm hadoopConf.delete(out, recursive = true) ExportVCF(vsmAnn, out) diff --git a/src/test/scala/is/hail/io/HardCallsSuite.scala b/src/test/scala/is/hail/io/HardCallsSuite.scala index 824b0afceaa..6a272a6c491 100644 --- a/src/test/scala/is/hail/io/HardCallsSuite.scala +++ b/src/test/scala/is/hail/io/HardCallsSuite.scala @@ -9,7 +9,7 @@ import org.testng.annotations.Test class HardCallsSuite extends SparkSuite { @Test def test() { val p = forAll(MatrixTable.gen(hc, VSMSubgen.random)) { vds => - val hard = vds.selectEntries("g.GT") + val hard = vds.selectEntries("{GT: g.GT}") assert(hard.queryEntries("AGG.map(g => g.GT).counter()") == vds.queryEntries("AGG.map(g => g.GT).counter()")) diff --git a/src/test/scala/is/hail/linalg/LocalMatrixSuite.scala b/src/test/scala/is/hail/linalg/LocalMatrixSuite.scala deleted file mode 100644 index 38a755067ed..00000000000 --- a/src/test/scala/is/hail/linalg/LocalMatrixSuite.scala +++ /dev/null @@ -1,189 +0,0 @@ -package is.hail.linalg - -import is.hail.{SparkSuite, TestUtils} -import org.testng.annotations.Test -import breeze.linalg.{DenseVector => BDV} - -class LocalMatrixSuite extends SparkSuite { - - def assertEqual(lm1: LocalMatrix, lm2: LocalMatrix) { assert(lm1.m === lm2.m) } - def assertApproxEqual(lm1: LocalMatrix, lm2: LocalMatrix) { TestUtils.assertMatrixEqualityDouble(lm1.m, lm2.m) } - - @Test - def applyWriteRead() { - val fname = tmpDir.createTempFile("test") - - val m1 = LocalMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - val m2 = LocalMatrix(2, 3, Array(1.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0, 6.0, 0.0), 0, 3, isTransposed = false) - val m3 = LocalMatrix(2, 3, Array(0.0, 1.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0, 6.0), 1, 3, isTransposed = false) - val m4 = LocalMatrix(2, 3, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0), isTransposed = true) - val m5 = LocalMatrix(2, 3, Array(0.0, 1.0, 3.0, 5.0, 0.0, 2.0, 4.0, 6.0), 1, 4, isTransposed = true) - - for { m <- Seq(m1, m2, m3, m4, m5) } { - m.write(hc, fname) - assertEqual(LocalMatrix.read(hc, fname), m1) - } - - val v1 = LocalMatrix(Array(1.0, 2.0)) - val v2 = LocalMatrix(BDV[Double](1.0, 
2.0)) - val v3 = LocalMatrix(new BDV[Double](Array(0.0, 1.0, 0.0, 2.0, 0.0, 0.0), offset = 1, stride = 2, length = 2)) - - for { v <- Seq(v1, v2, v3) } { - v.write(hc, fname) - assertEqual(LocalMatrix.read(hc, fname), v1) - } - } - - @Test - def checkShapesTest() { - import TestUtils.interceptFatal - - val x = LocalMatrix(1, 1, Array(2.0)) - val c = LocalMatrix(2, 1, Array(1.0, 2.0)) - val r = LocalMatrix(1, 3, Array(1.0, 2.0, 3.0)) - val m = LocalMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - - LocalMatrix.checkShapes(m, m, "") - LocalMatrix.checkShapes(m, r, "") - LocalMatrix.checkShapes(m, c, "") - LocalMatrix.checkShapes(m, x, "") - LocalMatrix.checkShapes(r, m, "") - LocalMatrix.checkShapes(r, r, "") - LocalMatrix.checkShapes(r, c, "") - LocalMatrix.checkShapes(r, x, "") - LocalMatrix.checkShapes(c, m, "") - LocalMatrix.checkShapes(c, r, "") - LocalMatrix.checkShapes(c, c, "") - LocalMatrix.checkShapes(c, x, "") - LocalMatrix.checkShapes(x, m, "") - LocalMatrix.checkShapes(x, r, "") - LocalMatrix.checkShapes(x, c, "") - LocalMatrix.checkShapes(x, x, "") - - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(m.t, m, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(m.t, r, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(m.t, c, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(m, m.t, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(m, r.t, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(m, c.t, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(r.t, m, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(r.t, c, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(r, m.t, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(r, c.t, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(c.t, m, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(c.t, r, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(c, m.t, "") } - interceptFatal("Incompatible shapes") { LocalMatrix.checkShapes(c, r.t, "") } - } - - @Test - def ops() { - import LocalMatrix.ops._ - import TestUtils.interceptFatal - - val e = 2.0 - val x = LocalMatrix(1, 1, Array(2.0)) - val c = LocalMatrix(2, 1, Array(1.0, 2.0)) - val r = LocalMatrix(1, 3, Array(1.0, 2.0, 3.0)) - val m = LocalMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - - // add - assertEqual(x + x, e * x) - assertEqual(c + c, e * c) - assertEqual(r + r, e * r) - assertEqual(m + m, e * m) - - assertEqual(x + c, LocalMatrix(2, 1, Array(3.0, 4.0))) - assertEqual(x + c, c + x) - assertEqual(x + r, LocalMatrix(1, 3, Array(3.0, 4.0, 5.0))) - assertEqual(x + r, r + x) - assertEqual(x + m, LocalMatrix(2, 3, Array(3.0, 4.0, 5.0, 6.0, 7.0, 8.0))) - assertEqual(x + m, m + x) - assertEqual(x + m, e + m) - assertEqual(x + m, m + e) - - assertEqual(c + m, LocalMatrix(2, 3, Array(2.0, 4.0, 4.0, 6.0, 6.0, 8.0))) - assertEqual(c + m, m + c) - assertEqual(r + m, LocalMatrix(2, 3, Array(2.0, 3.0, 5.0, 6.0, 8.0, 9.0))) - assertEqual(r + m, m + r) - assertEqual(c + r, LocalMatrix(2, 3, Array(2.0, 3.0, 3.0, 4.0, 4.0, 5.0))) - assertEqual(c + r, r + c) - - interceptFatal("addition") { m.t + m } - - // subtract - assertEqual((x + x) - x, x) - assertEqual((c + c) - c, c) - assertEqual((r + r) - r, r) - assertEqual((m + m) - m, m) - - assertEqual(x - c, LocalMatrix(2, 1, Array(1.0, 0.0))) - assertEqual(x - c, -(c - x)) 
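// Editor's aside (hedged): the deleted assertions in this suite exercise
// R-style elementwise broadcasting, where a 1x1 scalar, 2x1 column, or 1x3 row
// operand is expanded to m's 2x3 shape. Breeze writes the column broadcast
// explicitly (illustrative, not LocalMatrix's API):
//
//   import breeze.linalg._
//
//   val m = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) // column-major
//   val c = DenseVector(1.0, 2.0)
//   m(::, *) - c  // subtract c from every column: [[0, 2, 4], [0, 2, 4]]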
- assertEqual(x - r, LocalMatrix(1, 3, Array(1.0, 0.0, -1.0))) - assertEqual(x - r, -(r - x)) - assertEqual(x - m, LocalMatrix(2, 3, Array(1.0, 0.0, -1.0, -2.0, -3.0, -4.0))) - assertEqual(x - m, -(m - x)) - assertEqual(x - m, e - m) - assertEqual(x - m, -(m - e)) - - assertEqual(c - m, LocalMatrix(2, 3, Array(0.0, 0.0, -2.0, -2.0, -4.0, -4.0))) - assertEqual(c - m, -(m - c)) - assertEqual(r - m, LocalMatrix(2, 3, Array(0.0, -1.0, -1.0, -2.0, -2.0, -3.0))) - assertEqual(r - m, -(m - r)) - assertEqual(c - r, LocalMatrix(2, 3, Array(0.0, 1.0, -1.0, 0.0, -2.0, -1.0))) - assertEqual(c - r, -(r - c)) - - interceptFatal("subtraction") { m.t - m } - - // multiply - assertEqual(x * x, LocalMatrix(1, 1, Array(4.0))) - assertEqual(c * c, LocalMatrix(2, 1, Array(1.0, 4.0))) - assertEqual(r * r, LocalMatrix(1, 3, Array(1.0, 4.0, 9.0))) - assertEqual(m * m, LocalMatrix(2, 3, Array(1.0, 4.0, 9.0, 16.0, 25.0, 36.0))) - - assertEqual(x * c, LocalMatrix(2, 1, Array(2.0, 4.0))) - assertEqual(x * c, c * x) - assertEqual(x * r, LocalMatrix(1, 3, Array(2.0, 4.0, 6.0))) - assertEqual(x * r, r * x) - assertEqual(x * m, LocalMatrix(2, 3, Array(2.0, 4.0, 6.0, 8.0, 10.0, 12.0))) - assertEqual(x * m, m * x) - assertEqual(x * m, e * m) - assertEqual(x * m, m * e) - - assertEqual(c * m, LocalMatrix(2, 3, Array(1.0, 4.0, 3.0, 8.0, 5.0, 12.0))) - assertEqual(c * m, m * c) - assertEqual(r * m, LocalMatrix(2, 3, Array(1.0, 2.0, 6.0, 8.0, 15.0, 18.0))) - assertEqual(r * m, m * r) - assertEqual(c * r, LocalMatrix(2, 3, Array(1.0, 2.0, 2.0, 4.0, 3.0, 6.0))) - - interceptFatal("multiplication") { m.t * m } - - // divide - assertApproxEqual((x * x) / x, x) - assertApproxEqual((c * c) / c, c) - assertApproxEqual((r * r) / r, r) - assertApproxEqual((m * m) / m, m) - - assertApproxEqual(x / c, LocalMatrix(2, 1, Array(2.0, 1.0))) - assertApproxEqual(x / c, 1.0 / (c / x)) - assertApproxEqual(x / r, LocalMatrix(1, 3, Array(2.0, 1.0, 2.0 / 3))) - assertApproxEqual(x / r, 1.0 / (r / x)) - assertApproxEqual(x / m, LocalMatrix(2, 3, Array(2.0, 1.0, 2.0 / 3, 0.5, 0.4, 1.0 / 3))) - assertApproxEqual(x / m, 1.0 / (m / x)) - assertApproxEqual(x / m, e / m) - assertApproxEqual(x / m, 1.0 / (m / e)) - - assertApproxEqual(c / m, LocalMatrix(2, 3, Array(1.0, 1.0, 1.0 / 3, 0.5, 0.2, 1.0 / 3))) - assertApproxEqual(c / m, 1.0 / (m / c)) - assertApproxEqual(r / m, LocalMatrix(2, 3, Array(1.0, 0.5, 2.0 / 3, 0.5, 0.6, 0.5))) - assertApproxEqual(r / m, 1.0 / (m / r)) - assertApproxEqual(c / r, LocalMatrix(2, 3, Array(1.0, 2.0, 0.5, 1.0, 1.0 / 3, 2.0 / 3))) - assertApproxEqual(c / r, 1.0 / (r / c)) - - interceptFatal("division") { m.t / m } - - assertApproxEqual(m.sqrt(), LocalMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).map(math.sqrt))) - - assertApproxEqual(m.pow(0.5), m.sqrt()) - } -} diff --git a/src/test/scala/is/hail/methods/PCRelateSuite.scala b/src/test/scala/is/hail/methods/PCRelateSuite.scala index 86deb11bddc..cf522ad9432 100644 --- a/src/test/scala/is/hail/methods/PCRelateSuite.scala +++ b/src/test/scala/is/hail/methods/PCRelateSuite.scala @@ -157,7 +157,8 @@ class PCRelateSuite extends SparkSuite { None } if (!fails.isEmpty) - fails.foreach(println _) + println(fails.length) + fails.foreach(println) assert(fails.isEmpty) } @@ -171,14 +172,34 @@ class PCRelateSuite extends SparkSuite { val pcr = new PCRelate(0.01, blockSize, PCRelate.PhiK2K0K1) val g = PCRelate.vdsToMeanImputedMatrix(vds) - val dmu = pcr.mu(BlockMatrix.fromIRM(g, blockSize), pcs).cache() + // blockedG : variant x sample val blockedG = BlockMatrix.fromIRM(g, blockSize) + 
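// Editor's aside (hedged, not part of the patch): the BlockMatrix.map2 call
// added just below masks entries PC-Relate cannot use. Distilled to a plain
// function (the name and factoring are illustrative):
//
//   def maskMu(g: Double, mu: Double, maf: Double): Double = {
//     val badgt = g != 0.0 && g != 1.0 && g != 2.0 // not a hard call
//     val badmu = mu <= maf || mu >= 1.0 - maf || mu <= 0.0 || mu >= 1.0
//     if (badgt || badmu) Double.NaN else mu       // NaN drops the entry downstream
//   }
//
//   maskMu(1.0, 0.30, 0.01)        // 0.30, kept
//   maskMu(1.5, 0.30, 0.01).isNaN  // true: dosage-like genotype is masked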
+ val predmu = pcr.mu(BlockMatrix.fromIRM(g, blockSize), pcs) + + val dmu = BlockMatrix.map2 { (g, mu) => + def badmu(mu: Double, maf: Double): Boolean = + mu <= maf || mu >= (1.0 - maf) || mu <= 0.0 || mu >= 1.0 + + def badgt(gt: Double): Boolean = + gt != 0.0 && gt != 1.0 && gt != 2.0 + + if (badgt(g) || badmu(mu, 0.01)) + Double.NaN + else + mu + } (blockedG, predmu).cache() + val actual = runPcRelateHail(vds, pcs, 0.01) val actual_g = blockedG.toBreezeMatrix().t val actual_ibs0 = pcr.ibs0(blockedG, dmu, blockSize).toBreezeMatrix() val actual_mean = dmu.toBreezeMatrix() + println(blockedG.toBreezeMatrix().toArray.count(g => g != 0.0 && g != 1.0 && g != 2.0)) + println(actual_ibs0) + println(truth_ibs0) + compareBDMs(actual_mean, truth_mu, tolerance=1e-14) compareBDMs(actual_ibs0, truth_ibs0, tolerance=1e-14) compareBDMs(actual_g, truth_g, tolerance=1e-14) diff --git a/src/test/scala/is/hail/methods/TableSuite.scala b/src/test/scala/is/hail/methods/TableSuite.scala index a387d7cfc70..f8966e4f711 100644 --- a/src/test/scala/is/hail/methods/TableSuite.scala +++ b/src/test/scala/is/hail/methods/TableSuite.scala @@ -202,7 +202,7 @@ class TableSuite extends SparkSuite { val ktRight = hc.importTable(inputFile2, impute = true).keyBy("Sample").rename(Map("Sample" -> "sample"), Map()) val ktBad = ktRight.keyBy("qPhen2") - intercept[HailException] { + intercept[Exception] { val ktJoinBad = ktLeft.join(ktBad, "left") assert(ktJoinBad.key sameElements Array("Sample")) } @@ -477,4 +477,10 @@ class TableSuite extends SparkSuite { val expectedSets = List(Set("A", "C", "E", "G"), Set("A", "C", "E", "H")).map(set => set.map(str => (str, true))) assert(expectedSets.contains(mis)) } + + @Test def testSelectGlobals() { + val kt = hc.importTable("src/test/resources/sampleAnnotations.tsv", impute = true) + val kt2 = kt.selectGlobal("{x: 5}").selectGlobal("{y: global.x}") + assert(kt2.globalSignature == TStruct("y" -> TInt32()) && kt2.globals.value.asInstanceOf[Row].get(0) == 5) + } } diff --git a/src/test/scala/is/hail/rvd/OrderedRVDPartitionerSuite.scala b/src/test/scala/is/hail/rvd/OrderedRVDPartitionerSuite.scala new file mode 100644 index 00000000000..8c93295ad3f --- /dev/null +++ b/src/test/scala/is/hail/rvd/OrderedRVDPartitionerSuite.scala @@ -0,0 +1,46 @@ +package is.hail.rvd + +import is.hail.annotations.UnsafeIndexedSeq +import is.hail.expr.types._ +import is.hail.utils.Interval +import org.apache.spark.sql.Row +import org.scalatest.testng.TestNGSuite +import org.testng.annotations.Test + +class OrderedRVDPartitionerSuite extends TestNGSuite { + val partitioner = + new OrderedRVDPartitioner( + Array("A", "B"), + TStruct(("A", TInt32()), ("C", TInt32()), ("B", TInt32())), + UnsafeIndexedSeq( + TArray(TInterval(TTuple(TInt32(), TInt32()), true), true), + IndexedSeq( + Interval(Row(1, 0), Row(4, 3), true, false), + Interval(Row(4, 3), Row(7, 9), true, false), + Interval(Row(7, 9), Row(10, 0), true, true))) + ) + + @Test def testGetPartitionPKWithPartitionKeys() { + assert(partitioner.getPartitionPK(Row(-1, 7)) == 0) + assert(partitioner.getPartitionPK(Row(4, 2)) == 0) + assert(partitioner.getPartitionPK(Row(4, 3)) == 1) + assert(partitioner.getPartitionPK(Row(5, -10259)) == 1) + assert(partitioner.getPartitionPK(Row(7, 8)) == 1) + assert(partitioner.getPartitionPK(Row(7, 9)) == 2) + assert(partitioner.getPartitionPK(Row(10, 0)) == 2) + assert(partitioner.getPartitionPK(Row(12, 19)) == 2) + } + + @Test def testGetPartitionPKWithLargerKeys() { + assert(partitioner.getPartitionPK(Row(0, 1, 3)) == 0) + 
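// Editor's aside (hedged sketch, not Hail's implementation): these assertions
// pin down getPartitionPK as a clamped lookup over the interval start bounds,
// comparing only the partition-key prefix (A, B) and ignoring extra trailing
// fields in the query row:
//
//   def partitionOf(starts: IndexedSeq[(Int, Int)], key: (Int, Int)): Int = {
//     import scala.math.Ordering.Implicits._
//     math.max(starts.lastIndexWhere(_ <= key), 0) // keys below the first bound clamp to 0
//   }
//
//   val starts = IndexedSeq((1, 0), (4, 3), (7, 9))
//   partitionOf(starts, (4, 2))   // 0
//   partitionOf(starts, (4, 3))   // 1
//   partitionOf(starts, (12, 19)) // 2 (clamped at the top)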
assert(partitioner.getPartitionPK(Row(2, 7, "foo")) == 0) + assert(partitioner.getPartitionPK(Row(4, 2, 1, 2.7, "bar")) == 0) + assert(partitioner.getPartitionPK(Row(4, 3, 5)) == 1) + assert(partitioner.getPartitionPK(Row(7, 9, 7)) == 2) + assert(partitioner.getPartitionPK(Row(11, 1, 42)) == 2) + } + + // @Test def testGetPartitionPKWithSmallerKeys() { + // assert(partitioner.getPartitionPK(Row(2)) == 0) + // } +} diff --git a/src/test/scala/is/hail/stats/BaldingNicholsModelSuite.scala b/src/test/scala/is/hail/stats/BaldingNicholsModelSuite.scala index be38831bbb4..80d10ed496a 100644 --- a/src/test/scala/is/hail/stats/BaldingNicholsModelSuite.scala +++ b/src/test/scala/is/hail/stats/BaldingNicholsModelSuite.scala @@ -2,7 +2,7 @@ package is.hail.stats import breeze.stats._ import is.hail.SparkSuite -import is.hail.variant.{Call, Locus, Variant} +import is.hail.variant.{Call, Variant} import is.hail.testUtils._ import org.apache.spark.sql.Row import org.testng.Assert.assertEquals diff --git a/src/test/scala/is/hail/testUtils/RichTable.scala b/src/test/scala/is/hail/testUtils/RichTable.scala index a455842ca95..9b37dd75b01 100644 --- a/src/test/scala/is/hail/testUtils/RichTable.scala +++ b/src/test/scala/is/hail/testUtils/RichTable.scala @@ -43,7 +43,7 @@ class RichTable(ht: Table) { def rename(rowUpdateMap: Map[String, String], globalUpdateMap: Map[String, String]): Table = { select(ht.fieldNames.map(n => s"${ rowUpdateMap.getOrElse(n, n) } = row.$n")) .keyBy(ht.key.map(k => rowUpdateMap.getOrElse(k, k))) - .selectGlobal(ht.globalSignature.fieldNames.map(n => s"${ globalUpdateMap.getOrElse(n, n) } = global.$n"): _*) + .selectGlobal(ht.globalSignature.fieldNames.map(n => s"${ globalUpdateMap.getOrElse(n, n) } = global.$n")) } def select(exprs: Array[String]): Table = { @@ -126,4 +126,66 @@ class RichTable(ht: Table) { new Table(ht.hc, TableMapRows(ht.tir, newIR)) } } + + def selectGlobal(fields: Array[String]): Table = { + val ec = EvalContext("global" -> ht.globalSignature) + ec.set(0, ht.globals.value) + + val (paths, types, f) = Parser.parseSelectExprs(fields, ec) + + val names = paths.map { + case Left(n) => n + case Right(l) => l.last + } + + val overlappingPaths = names.counter().filter { case (n, i) => i != 1 }.keys + + if (overlappingPaths.nonEmpty) + fatal(s"Found ${ overlappingPaths.size } ${ plural(overlappingPaths.size, "selected field name") } that are duplicated.\n" + + "Overlapping fields:\n " + + s"@1", overlappingPaths.truncatable("\n ")) + + val inserterBuilder = new ArrayBuilder[Inserter]() + + val finalSignature = (names, types).zipped.foldLeft(TStruct()) { case (vs, (p, sig)) => + val (s: TStruct, i) = vs.insert(sig, p) + inserterBuilder += i + s + } + + val inserters = inserterBuilder.result() + + val newGlobal = f().zip(inserters) + .foldLeft(Row()) { case (a1, (v, inserter)) => + inserter(a1, v).asInstanceOf[Row] + } + + ht.copy2(globalSignature = finalSignature, globals = ht.globals.copy(value = newGlobal, t = finalSignature)) + } + + def annotateGlobalExpr(expr: String): Table = { + val ec = EvalContext("global" -> ht.globalSignature) + ec.set(0, ht.globals.value) + + val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, None) + + val inserterBuilder = new ArrayBuilder[Inserter]() + + val finalType = (paths, types).zipped.foldLeft(ht.globalSignature) { case (v, (ids, signature)) => + val (s, i) = v.insert(signature, ids) + inserterBuilder += i + s.asInstanceOf[TStruct] + } + + val inserters = inserterBuilder.result() + + val ga = inserters + .zip(f()) + 
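// Editor's aside: selectGlobal(fields: String*) and annotateGlobalExpr were
// deleted from Table.scala above; these RichTable copies keep the old
// Inserter-based semantics available to existing test suites, while production
// code now routes through TableMapGlobals and the IR.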
.foldLeft(ht.globals.value) { case (a, (ins, res)) => + ins(a, res).asInstanceOf[Row] + } + + ht.copy2(globals = ht.globals.copy(value = ga, t = finalType), + globalSignature = finalType) + } } diff --git a/src/test/scala/is/hail/utils/RichMatrixTable.scala b/src/test/scala/is/hail/utils/RichMatrixTable.scala index 7ee3035e448..e462e5f7b3f 100644 --- a/src/test/scala/is/hail/utils/RichMatrixTable.scala +++ b/src/test/scala/is/hail/utils/RichMatrixTable.scala @@ -43,6 +43,9 @@ class RichMatrixTable(vsm: MatrixTable) { vsm.annotateCols(t, i) { case (_, i) => annotation(sampleIds(i)) } } + def annotateEntriesExpr(exprs: (String, String)*): MatrixTable = + vsm.selectEntries(s"annotate(g, {${ exprs.map { case (n, e) => s"`$n`: $e" }.mkString(",") }})") + def querySA(code: String): (Type, Querier) = { val st = Map(Annotation.COL_HEAD -> (0, vsm.colType)) val ec = EvalContext(st) diff --git a/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala b/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala index 96c0d0604c1..53d15175420 100644 --- a/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala +++ b/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala @@ -75,7 +75,7 @@ class PartitioningSuite extends SparkSuite { val mt = MatrixTable.fromRowsTable(Table.range(hc, 100, "idx", partitions=Some(6))) val orvdType = mt.matrixType.orvdType - mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left").count() - mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner").count() + mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left", _.map(_._1), orvdType).count() + mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner", _.map(_._1), orvdType).count() } } diff --git a/src/test/scala/is/hail/variant/vsm/VSMSuite.scala b/src/test/scala/is/hail/variant/vsm/VSMSuite.scala index f5fc0f54c0e..a6679d577cb 100644 --- a/src/test/scala/is/hail/variant/vsm/VSMSuite.scala +++ b/src/test/scala/is/hail/variant/vsm/VSMSuite.scala @@ -154,7 +154,7 @@ class VSMSuite extends SparkSuite { .indexRows("rowIdx") .indexCols("colIdx") - mt.selectEntries("x = (g.GT.nNonRefAlleles().toInt64 + va.rowIdx + sa.colIdx.toInt64 + 1L).toFloat64") + mt.selectEntries("{x: (g.GT.nNonRefAlleles().toInt64 + va.rowIdx + sa.colIdx.toInt64 + 1L).toFloat64}") .writeBlockMatrix(dirname, "x", blockSize) val data = mt.entriesTable() diff --git a/src/test/scala/is/hail/vds/JoinSuite.scala b/src/test/scala/is/hail/vds/JoinSuite.scala index 10ef3c963bc..1020844eb98 100644 --- a/src/test/scala/is/hail/vds/JoinSuite.scala +++ b/src/test/scala/is/hail/vds/JoinSuite.scala @@ -67,24 +67,23 @@ class JoinSuite extends SparkSuite { val localRowType = left.rvRowType // Inner distinct ordered join - val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner") + val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner", _.map(_._1), left.rvd.typ) val jInnerOrdRDD1 = left.rdd.join(right.rdd.distinct) assert(jInner.count() == jInnerOrdRDD1.count()) - assert(jInner.forall(jrv => jrv.rvLeft != null && jrv.rvRight != null)) - assert(jInner.map { jrv => - val ur = new UnsafeRow(localRowType, jrv.rvLeft) + assert(jInner.map { rv => + val ur = new UnsafeRow(localRowType, rv) ur.getAs[Locus](0) }.collect() sameElements jInnerOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering)) // Left distinct ordered join - val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left") + val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left", _.map(_._1), left.rvd.typ) val 
jLeftOrdRDD1 = left.rdd.leftOuterJoin(right.rdd.distinct) assert(jLeft.count() == jLeftOrdRDD1.count()) - assert(jLeft.forall(jrv => jrv.rvLeft != null)) - assert(jLeft.map { jrv => - val ur = new UnsafeRow(localRowType, jrv.rvLeft) + assert(jLeft.rdd.forall(rv => rv != null)) + assert(jLeft.map { rv => + val ur = new UnsafeRow(localRowType, rv) ur.getAs[Locus](0) }.collect() sameElements jLeftOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering)) }
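For the record, the JoinSuite and PartitioningSuite call sites above use the new orderedJoinDistinct signature, which takes an explicit joiner from Iterator[JoinedRegionValue] to Iterator[RegionValue] plus the result OrderedRVDType; passing `_.map(_._1)` keeps only the left record. A hedged sketch of the same streaming left-join-distinct shape over plain sorted iterators (the names are illustrative, not Hail's API):

  // Both inputs must be sorted by key; `right` must have distinct keys.
  def leftJoinDistinct[K, A, B](left: Iterator[(K, A)], right: Iterator[(K, B)])
      (implicit ord: Ordering[K]): Iterator[(K, A, Option[B])] = {
    val r = right.buffered
    left.map { case (k, a) =>
      while (r.hasNext && ord.lt(r.head._1, k)) r.next() // advance past smaller right keys
      val hit = if (r.hasNext && ord.equiv(r.head._1, k)) Some(r.head._2) else None
      (k, a, hit)
    }
  }

  // leftJoinDistinct(Iterator(1 -> "a", 2 -> "b"), Iterator(2 -> "x")).toList
  //   == List((1, "a", None), (2, "b", Some("x")))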