From 567acfd7e2b42e5074ea46a24df75ce168b3de16 Mon Sep 17 00:00:00 2001 From: Patrick Schultz Date: Fri, 2 Dec 2022 10:39:07 -0500 Subject: [PATCH] [compiler] enable new rng (#12139) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add stream expressions * enable new rng * delete global seed management * fixes * persist = checkpoint * fixes * wip * fixup stream randomness handling * fixes * set global seed in test setup * try to catch who’s making a HailContext * start to update doctests * init hunt take 2 * refactor hail context management in tests * fixes * update random.rst * typo * fixes * fix docs * add dev doc * fix _seeded_func seed handling * address comments * delete ‘startTestHailContext’ in new python fs tests * fix bad doctest from merge * remove unittest to try to get pytest fixture used Co-authored-by: Dan King --- dev-docs/hail-query/randomness.md | 193 ++++++++++ hail/python/hail/__init__.py | 8 +- hail/python/hail/conftest.py | 11 +- hail/python/hail/context.py | 34 +- hail/python/hail/docs/api.rst | 1 + hail/python/hail/docs/functions/random.rst | 176 ++++++---- hail/python/hail/docs/guides/agg.rst | 10 +- hail/python/hail/docs/scans.rst | 20 +- .../hail/experimental/full_outer_join_mt.py | 22 +- hail/python/hail/experimental/ldscsim.py | 2 - .../hail/experimental/table_ndarray_utils.py | 2 + .../hail/expr/aggregators/aggregators.py | 42 ++- .../hail/expr/expressions/base_expression.py | 43 ++- hail/python/hail/expr/functions.py | 117 +++--- hail/python/hail/ir/__init__.py | 5 +- hail/python/hail/ir/base_ir.py | 15 +- hail/python/hail/ir/blockmatrix_ir.py | 16 +- hail/python/hail/ir/ir.py | 184 +++++----- hail/python/hail/ir/matrix_ir.py | 135 ++++--- hail/python/hail/ir/table_ir.py | 153 ++++---- hail/python/hail/ir/utils.py | 40 +-- hail/python/hail/linalg/blockmatrix.py | 4 +- hail/python/hail/matrixtable.py | 6 +- hail/python/hail/methods/impex.py | 18 +- hail/python/hail/methods/pca.py | 10 +- .../relatedness/identity_by_descent.py | 2 +- .../hail/methods/relatedness/pc_relate.py | 4 +- hail/python/hail/methods/statgen.py | 33 +- hail/python/hail/table.py | 2 +- hail/python/hail/utils/__init__.py | 3 +- hail/python/hail/utils/java.py | 16 +- hail/python/hail/utils/misc.py | 13 - hail/python/test/hail/conftest.py | 17 +- .../hail/experimental/test_annotation_db.py | 28 +- .../test/hail/experimental/test_codec.py | 3 - .../hail/experimental/test_experimental.py | 3 - .../hail/experimental/test_vcf_combiner.py | 5 +- hail/python/test/hail/expr/test_expr.py | 121 ++++++- hail/python/test/hail/expr/test_ndarrays.py | 3 - hail/python/test/hail/expr/test_show.py | 4 - hail/python/test/hail/expr/test_types.py | 3 - .../test/hail/fs/test_worker_driver_fs.py | 4 - hail/python/test/hail/genetics/test_call.py | 3 - hail/python/test/hail/genetics/test_locus.py | 3 - .../test/hail/genetics/test_pedigree.py | 3 - .../hail/genetics/test_reference_genome.py | 3 - hail/python/test/hail/helpers.py | 17 +- hail/python/test/hail/linalg/test_linalg.py | 3 - .../hail/matrixtable/test_file_formats.py | 7 +- .../matrixtable/test_grouped_matrix_table.py | 3 - .../hail/matrixtable/test_matrix_table.py | 332 +++++++++++++++++- .../relatedness/test_identity_by_descent.py | 6 +- .../methods/relatedness/test_pc_relate.py | 5 +- .../test/hail/methods/test_family_methods.py | 3 - hail/python/test/hail/methods/test_impex.py | 3 - hail/python/test/hail/methods/test_king.py | 5 +- hail/python/test/hail/methods/test_misc.py | 3 - hail/python/test/hail/methods/test_pca.py | 7 +- hail/python/test/hail/methods/test_qc.py | 3 - hail/python/test/hail/methods/test_statgen.py | 88 ++--- .../test/hail/table/test_grouped_table.py | 4 - hail/python/test/hail/table/test_table.py | 246 ++++++++++++- hail/python/test/hail/test_context.py | 4 - hail/python/test/hail/test_ir.py | 7 +- .../hail/utils/test_hl_hadoop_and_hail_fs.py | 5 - .../test/hail/utils/test_placement_tree.py | 4 - hail/python/test/hail/utils/test_utils.py | 32 +- hail/python/test/hail/vds/test_combiner.py | 5 +- hail/python/test/hail/vds/test_vds.py | 5 +- .../test/hail/vds/test_vds_functions.py | 5 - .../main/scala/is/hail/HailFeatureFlags.scala | 3 +- .../is/hail/backend/ExecuteContext.scala | 2 +- .../hail/backend/service/ServiceBackend.scala | 4 +- .../scala/is/hail/expr/ir/BlockMatrixIR.scala | 6 +- .../main/scala/is/hail/expr/ir/Children.scala | 4 +- .../src/main/scala/is/hail/expr/ir/Copy.scala | 6 +- .../src/main/scala/is/hail/expr/ir/Emit.scala | 57 ++- .../is/hail/expr/ir/EmitClassBuilder.scala | 28 +- hail/src/main/scala/is/hail/expr/ir/IR.scala | 13 +- .../scala/is/hail/expr/ir/InferType.scala | 2 +- .../scala/is/hail/expr/ir/LowerMatrixIR.scala | 9 +- .../main/scala/is/hail/expr/ir/Parser.scala | 10 +- .../main/scala/is/hail/expr/ir/Pretty.scala | 8 +- .../is/hail/expr/ir/PruneDeadFields.scala | 11 +- .../main/scala/is/hail/expr/ir/Random.scala | 279 +++++++-------- .../scala/is/hail/expr/ir/Requiredness.scala | 6 +- .../main/scala/is/hail/expr/ir/Simplify.scala | 4 +- .../main/scala/is/hail/expr/ir/TableIR.scala | 2 +- .../scala/is/hail/expr/ir/TypeCheck.scala | 4 +- .../is/hail/expr/ir/functions/Functions.scala | 118 +------ .../ir/functions/RandomSeededFunctions.scala | 200 +++++++---- .../expr/ir/lowering/LowerBlockMatrixIR.scala | 17 +- .../hail/expr/ir/lowering/LowerTableIR.scala | 2 +- .../main/scala/is/hail/expr/ir/package.scala | 9 +- .../scala/is/hail/linalg/BlockMatrix.scala | 31 +- .../types/physical/PCanonicalNDArray.scala | 8 + .../physical/stypes/concrete/SRNGState.scala | 101 ++++-- .../physical/stypes/interfaces/SNDArray.scala | 5 +- .../is/hail/expr/ir/FoldConstantsSuite.scala | 2 +- .../test/scala/is/hail/expr/ir/IRSuite.scala | 150 ++------ .../scala/is/hail/expr/ir/MatrixIRSuite.scala | 25 +- .../hail/expr/ir/RandomFunctionsSuite.scala | 153 -------- .../scala/is/hail/expr/ir/RandomSuite.scala | 192 ++++++++-- .../is/hail/linalg/BlockMatrixSuite.scala | 19 +- 104 files changed, 2246 insertions(+), 1589 deletions(-) create mode 100644 dev-docs/hail-query/randomness.md delete mode 100644 hail/src/test/scala/is/hail/expr/ir/RandomFunctionsSuite.scala diff --git a/dev-docs/hail-query/randomness.md b/dev-docs/hail-query/randomness.md new file mode 100644 index 00000000000..2c3df06a1ea --- /dev/null +++ b/dev-docs/hail-query/randomness.md @@ -0,0 +1,193 @@ +Our design for pseudo-random number generation is inspired by [1], but several details differ. At a high level, the idea is: +* Assign to each random function invocation some unique identifier. In general we can't bound the size of the identifier. We use arrays of longs. +* Use a construction of a psuedo-random function to map unique identifiers to random streams of bits. Intuitively, it's as if we used the identifier to seed a stateful RNG. + +The key property is that random function invocations with distinct identifiers produce independent random results, while invocations with the same identifier always produce the same result. Thus random function invocations are actually pure functions, with no side effects, which gives the compiler great freedom to optimize queries without affecting the results. + +Psuedo-random functions are important building blocks in cryptography, and so they are very well studied, with many different practical constructions. We use the PMAC message authentication code, which depends on a tweakable block cipher, for which we use a reduced-round Threefish. Either or both of these pieces could be replaced with little effort, e.g. to improve performance. + +# Threefish/Threefry: +We use the Threefish [2] block cipher, modified to use 20 rounds for efficiency (the full Threefish4x64 uses 72 rounds), as suggested by [3] (although we make use of the Threefish tweak). Reference implementation is `Threefry.encrypt`. + +`threefish4x64` takes: +* key `K = (k_0, ..., k_3)`: 4 words +* tweak `T = (t_0, t_1)`: 2 words +* plaintext `P = (p_0, ..., p_3)`: 4 words + +Intutively, this is a function taking a key and tweak as input, and returning a permutation on the space of all 256-bit blocks. The security claim is that if the key is chosen randomly, then for any choice of tweak, the resulting permutation "looks like" a uniformly chosen random permutation. + +Like most (all?) block ciphers, it is constructed as a sequence of simpler permutations. Think of shuffling a deck of cards: each shuffle isn't that random (is easily distinguishable from a completely random permutation), but a sequence of seven shuffles is indistinguishable from a random permutation. + +The simple permutations are called "rounds". Each round consists of applying a function "Mix" to pairs of 64-bit words, which is a bit-level permutation, followed by a permutation of the four words. + +threefish + +## key schedule +The key schedule turns the key and tweak into 6 subkeys, each 4 words. Subkey `s` is denoted `(k_{s,0}, ..., k_{s,3})`. + +First compute two additional words `k_4 = C ^ k_0 ^ k_1 ^ k_2 ^ k_3` and `t_2 = t0 ^ t_1`, where `C = 0x1BD11BDAA9FC1A22`. Then +``` +k_{s,0} = k_{s mod 5} +k_{s,1} = k_{s+1 mod 5} + t_{s mod 3} +k_{s,2} = k_{s+2 mod 5} + t_{s+1 mod 3} +k_{s,3} = k_{s+3 mod 5} + s +``` + +## an encryption round +Encryption is performed over 20 rounds. Let `v_i` be the `i`th word of the encryption state, initialized +``` +v_i = p_i +``` +Before round `d` if `d mod 4 = 0`, add subkey `s = d/4` +``` +v_i += k_{s,i} +``` +Then apply the `mix` function to adjacent pairs of words, where the rotation constant `r = R[d mod 8][j]` is looked up in a table. +``` +mix(v_{2j}, v_{2j+1}, r) +``` +`mix` is defined +``` +mix(x0, x1, r) { + x0 += x1 + rotL(x1, r) + x1 ^= x0 +} +``` +MIX + +Lastly, the words are permuted +``` +v_1, v_3 = v_3, v_1 +``` + +# PMAC +PMAC is a message authentication code. Intuitively, a MAC uses a block cipher to construct a function from abritrary length messages to 256 bit message tags. We extend this to a function from arbitrary length messages to "infinite" length message tags (really a very large finite length). The security claim is that if the block cipher used "looks like" a random permutation, then the MAC "looks like" a random function. In particular, for each message `m`, `pmac(m)` looks like a stream of random bits, and for distinct messages `m1` and `m2`, `pmac(m1)` and `pmac(m2)` look like completely independent streams of random bits. Yet this is a deterministic function, so computing `pmac` on the same message always produces the same stream of bits. + +Screen Shot 2022-10-25 at 11 53 13 AM + +Many MAC constructions must process blocks sequentially. As we'll see below, this would add significant overhead to random number generation. PMAC has the property that blocks of the message can be processed in any order. + +In our case, we use a modification of the PMAC1 construction in [4]. We restrict the message length to multiples of 64-bits for simplicity. Our modified PMAC is a function `pmac(nonce: Long, staticID: Long, message: Array[Long], counter: Long)`, defined as follows (reference implementation `Threefry.pmac`): +* Form a block `M[-1] = [nonce, staticID, 0L, 0L]`. +* Split `message` into blocks of 4 longs each, `M[0], ..., M[d]`, allowing the last block to be shorter. +* Let `E[i] = encrypt(key, [i, 0L], M[i])`, for `i=-1..d-1`, *all but the last block* +* Let `E` be the xor of all `E[i]` +* If the last block is not full, let `B` be `M[d]` padded by a single `1L` followed by `0L`s, to 4 longs. Otherwise, let `B = M[d]`. +* Compute the hash `H = E ^ B`. +* If the last block was full, compute the final MAC tag as + * `T = encrypt(key, [-2, counter], H)` +* otherwise + * `T = encrypt(key, [-3, counter], H)` + +The counter is used to enable generating long streams of random numbers for each message, not just a single 256 bit tag. The intuition is that each message (plus nonce and staticID) gets reduced to a 256 bit hash, such that distinct messages are highly unlikely to have the same hash. Then for each value of the counter, we use a distinct random function (really a random permutation) from the space of hashes to the space of random outputs. + +## Lazy computation +In practice, we don't need to save entire messages in memory. Instead we compute the hash on the fly. + +The new type is `RNGState`. A value of this type consists of the data: +* `runningSum: IndexedSeq[Long]`: the xor of the encrypted contents of all full blocks +* `lastDynBlock: IndexedSeq[Long]`: the partial contents of the last block. The length of the sequence is `numWordsInLastDynBlock` +* `numWordsInLastDynBlock: Int`: the number of words (longs), in the range `[0, 4)`, currently contained in `lastDynBlock` +* `hasStaticSplit: Boolean`: whether the static block has been incorporated into `runningSum` +* `numDynBlocks: Int`: the number of completed blocks, not including the static block + +This system is implemented using three IR nodes: +* `RNGStateLiteral` - creates an `RNGState` representing the empty message +* `RNGSplit(state: RNGState, dynBitstring: ?): RNGState` - appends to `lastDynBlock`. When the last block is full, encrypt it (using `numDynBlocks` for the tweak), and xor it into `runningSum`. Here `?` is either a single long, or an arbitrary sized tuple of longs. +* `ApplySeeded(..., rngState: RNGState, staticUID: Long)` + * Statically, forms the static block `[nonce, staticUID, 0L, 0L]`, encrypts it, and embeds the result as a literal in the code. + * At runtime, only needs to xor into the `runningSum` the encryped static block and the (possibly padded) `lastDynBlock`, and encrypt the result. Hence each `ApplySeeded` call only needs one invocation of the block cipher at runtime (more precisely, one invocation per 256 random bits needed by the random function). This minimizes the overhead of random number generation in inner loops, and is the reason for choose PMAC. + +# UIDs +To use the above PMAC scheme, we need to assign a "message" to every random function invocation in the program. As long as each invocation gets a distinct message, the PMAC random function generates approximately independent randomness for each invocation. + +We fix a key for the block cipher once and for all. It was generated randomly, and is hard coded in the compiler. This saves us from issues of users specifying "bad" keys. Instead, we reserve a part of the message to encode a session scoped uid. By changing that uid between sessions, we allow running identical code repeatedly with independent randomness. + +## Static UIDs +We split the message into static and dynamic components. The static component consists of two longs. The first, called the "rng nonce", is a hail session constant. It replaces the old "global seed", allowing the same pipeline to run with independent randomness each session, unless determinism is specifically requested. The second component is stored in the `ApplySeeded` IR node. We simply maintain a global counter, and increment it each time an `ApplySeeded` node is constructed, ensuring that each node in a pipeline has a distinct static uid. + +The dynamic component is needed to distinguish between different invocations of a single `ApplySeeded` node inside a looping construct. It is an arbitrary length message (though it will typically be quite small, probably less than 10 longs). It is constructed as follows: + +## Dynamic UIDs +Every stream, table, or matrix table pipeline is transformed to explicitly generate a unique uid per stream entry, table row, and matrix table row/column. These uids are explicit in the IR as ordinary values/fields, so the compiler automatically preserves the RNG determinism. + +## Putting it all together +Consider the example pipeline +``` +mt = hl.utils.range_matrix_table(10, 10) +mt = mt.annotate_entries(a=hl.range(10).map(lambda i: hl.rand_int32(100))) +``` +Before elaborating UIDs in the IR in python, the IR looks like this (after a little tidying): +``` +!1 = MatrixRead [DropRowColUIDs, ...] // don't add uid fields +!3 = MatrixMapEntries(!1) { +(%g, %col, %row, %entry) => + !c0 = I32 [0] + !c10 = I32 [10] + !c1 = I32 [1] + !s = StreamRange(!c0, !c10, !c1) [1, False] + !s2 = StreamMap(!s) { (%elt) => + !c100 = I32 [100] + ApplySeeded(!c100, %__rng_state) [rand_int32, 0, Int32] // unbound %__rng_state + } + !2 = ToArray(!s2) + InsertFields !entry (a: !2) +} +``` +Note that the `ApplySeeded` node is tageed with a static UID `0`, and references an unbound variable `__rng_state`. It is the responsibility of the `handle_randomness` pass to give proper definitions of `__rng_state` in any scope that needs it. After `handle_randomness` (and some more tidying), the IR looks like: +``` +// Now MatrixRead adds row and col uids +!1 = MatrixRead [None, False, False, (MatrixRangeReader MatrixRangeReaderParameters(10,10,None) 8)] +!11 = MatrixMapEntries(!1) { +(%g, %col, %row, %entry) => + !2 = RNGStateLiteral // RNGState corresponding to empty message + !3 = GetField(%row) [__row_uid] // get row and col uids + !4 = GetField(%col) [__col_uid] + !5 = MakeTuple(!3, !4) [(0 1)] + %6 = RNGSplit(!2, !5) // append row and col uids to message + !c0 = I32 [0] + !c10 = I32 [10] + !c1 = I32 [1] + !s = StreamRange(!c0, !c10, !c1) [1, False] + !s2 = StreamMap(!s) { (%elt) => + !7 = Cast(%elt) [Int64] + MakeTuple(!7, %elt) [(0 1)] // map to stream of (uid, elt) pairs + } + !s3 = StreamMap(!s2) { (%elt2) => + !8 = GetTupleElement(%elt2) [0] + %9 = RNGSplit(%6, !8) // append stream element uid to message + !c100 = I32 [100] + // call random function with current message/RNGState %9 and static uid 0 + ApplySeeded(!c100, %9) [rand_int32, 0, Int32] + } + !10 = ToArray(!s3) + InsertFields !entry (a: !10) +} +``` +Note that because only 3 longs are added to the message, none of the `RNGSplit` calls generate any runtime code. They simply encode statically that the last block of the message at the time of the `ApplySeeded` call consists of the locals `[!3, !4, !8]`. Then the `ApplySeeded` just needs to pad the last block, xor it with the running sum (which is the encrypted static block, embedded as a constant in the code), and call the Threefry `encrypt` function just once. + +# Security +Cryptogrophers have developed a very pragmatic theory of what makes for "good" pseudorandomness. One of the benefits of using cryptographic primitives (even while weakening some of the components for performance, as we do with Threefish) is that we can use this framework to evaluate how well users can trust the outputs of the RNG. + +Using this theory for a quick sanity check, consider a pipeline with a 1e7 row by 1e7 column matrixtable, with 1e4 random function invocations per entry, running for a year on 1e23 cores. Let `b` be any boolean output of this pipeline. Let `P_1` and `P_2` be the probabilities that `b=1` in the scenarios where random functions are truly random, and using the above scheme, respectively. Then `abs(P_1 - P_2) < 3e-40`. + +The only assumption in this bound is that Threefry is a secure block cipher, i.e. that the best attack against it is a brute force search of the space of all keys. The time bound comes from limiting how much of the key space the program is able to search. Clearly this will never be the weak link, and we can focus on how many random numbers are generated. + +This is a very practically reasurring result. It says that users can really trust that their results--interpreted under a model of true randomness--are not skewed by our implementation of pseudorandomness. + +# User interface +For the most part, users should not need to interact directly with the randomness infrastructure. If they don't, the default semantics are: +* Evaluating a hail expression multiple times in the same session always produces the same results +* Rebuilding an identical hail expression (e.g. `x = hl.rand_unif()` and `y = hl.rand_unif()`) evaluates with independent randomness. +* Running the same pipeline in multiple hail sessions uses independend randomness each time. + +The last two can be overridden if needed: +* To build identical expressions using the same randomness, manually specify "seeds" (should we rename this?) on each random function call. E.g. `x = hl.rand_unif(seed=0)`. This overrides using the global counter to populate the static uid. It is guaranteed that user specified static uids never clash with automatically generated ones. +* To run the same pipeline in multiple sessions with the same randomness, manually specify the "global seed" on init: `hl.init(global_seed=0)`. + + +[1] "Splittable pseudorandom number generators using cryptographic hashing" +[2] "The Skein Hash Function Family" +[3] "Parallel random numbers: as easy as 1, 2, 3" +[4] Rogaway, "Efficient Instantiations of Tweakable Blockciphers and Refinements to Modes OCB and PMAC" diff --git a/hail/python/hail/__init__.py b/hail/python/hail/__init__.py index d28101ea451..0ced5b61477 100644 --- a/hail/python/hail/__init__.py +++ b/hail/python/hail/__init__.py @@ -53,9 +53,10 @@ hadoop_is_dir, hadoop_scheme_supported, copy_log) from .context import (init, init_local, init_batch, stop, spark_context, tmp_dir, # noqa: E402 - default_reference, get_reference, set_global_seed, _set_flags, _get_flags, _with_flags, - _async_current_backend, current_backend, debug_info, citation, cite_hail, - cite_hail_bibtex, version, TemporaryFilename, TemporaryDirectory) + default_reference, get_reference, set_global_seed, reset_global_randomness, + _set_flags, _get_flags, _with_flags, _async_current_backend, + current_backend, debug_info, citation, cite_hail, cite_hail_bibtex, + version, TemporaryFilename, TemporaryDirectory) scan = agg.aggregators.ScanFunctions({name: getattr(agg, name) for name in agg.__all__}) @@ -71,6 +72,7 @@ 'default_reference', 'get_reference', 'set_global_seed', + 'reset_global_randomness', '_set_flags', '_get_flags', '_with_flags', diff --git a/hail/python/hail/conftest.py b/hail/python/hail/conftest.py index 7d7b36562e5..302d358c680 100644 --- a/hail/python/hail/conftest.py +++ b/hail/python/hail/conftest.py @@ -33,6 +33,10 @@ def init(doctest_namespace): olddir = os.getcwd() os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "docs")) + + hl.init(global_seed=0) + hl.reset_global_randomness() + try: generate_datasets(doctest_namespace) print("finished setting up doctest...") @@ -41,12 +45,17 @@ def init(doctest_namespace): os.chdir(olddir) +@pytest.fixture(autouse=True) +def reset_randomness(init): + hl.reset_global_randomness() + + def generate_datasets(doctest_namespace): doctest_namespace['hl'] = hl doctest_namespace['np'] = np ds = hl.import_vcf('data/sample.vcf.bgz') - ds = ds.sample_rows(0.03) + ds = ds.sample_rows(0.035) ds = ds.annotate_rows(use_as_marker=hl.rand_bool(0.5), panel_maf=0.1, anno1=5, diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 039d7ebec4f..07528a309bc 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -4,6 +4,7 @@ import os from contextlib import contextmanager from urllib.parse import urlparse, urlunparse +from random import Random import pkg_resources from pyspark import SparkContext @@ -115,14 +116,13 @@ def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backen ' the latest changes weekly.\n') sys.stderr.write(f'LOGGING: writing to {log}\n') + self._user_specified_rng_nonce = True if global_seed is None: - if Env._seed_generator is None: - Env.set_seed(6348563392232659379) - else: # global_seed is not None - if Env._seed_generator is not None: - raise ValueError( - 'Do not call hl.init with a non-None global seed *after* calling hl.set_global_seed') - Env.set_seed(global_seed) + if 'rng_nonce' not in backend.get_flags('rng_nonce'): + backend.set_flags(rng_nonce=hex(Random().randrange(-2**63, 2**63 - 1))) + self._user_specified_rng_nonce = False + else: + backend.set_flags(rng_nonce=hex(global_seed)) Env._hc = self def initialize_references(self, references, default_reference): @@ -805,7 +805,12 @@ def get_reference(name) -> ReferenceGenome: @typecheck(seed=int) def set_global_seed(seed): - """Sets Hail's global seed to `seed`. + """Deprecated. + + Has no effect. To ensure reproducible randomness, use the `global_seed` + argument to :func:`.init` and :func:`.reset_global_randomness`. + + See the :ref:`random functions ` reference docs for more. Parameters ---------- @@ -813,7 +818,18 @@ def set_global_seed(seed): Integer used to seed Hail's random number generator """ - Env.set_seed(seed) + warning('hl.set_global_seed has no effect. See ' + 'https://hail.is/docs/0.2/functions/random.html for details on ' + 'ensuring reproducibility of randomness.') + pass + + +@typecheck() +def reset_global_randomness(): + """Restore global randomness to initial state for test reproducibility. + """ + + Env.reset_global_randomness() def _set_flags(**flags): diff --git a/hail/python/hail/docs/api.rst b/hail/python/hail/docs/api.rst index 1b6a07ad73d..3d269c1466a 100644 --- a/hail/python/hail/docs/api.rst +++ b/hail/python/hail/docs/api.rst @@ -56,5 +56,6 @@ Top-Level Functions .. autofunction:: hail.default_reference .. autofunction:: hail.get_reference .. autofunction:: hail.set_global_seed +.. autofunction:: hail.reset_global_randomness .. autofunction:: hail.citation .. autofunction:: hail.version diff --git a/hail/python/hail/docs/functions/random.rst b/hail/python/hail/docs/functions/random.rst index bd72ed99bfd..4e557aecbb7 100644 --- a/hail/python/hail/docs/functions/random.rst +++ b/hail/python/hail/docs/functions/random.rst @@ -1,3 +1,5 @@ +.. _sec-random-functions: + Random functions ---------------- @@ -18,20 +20,20 @@ a random number generated with the function :func:`.rand_unif`: The value of `x` will not change, although other calls to :func:`.rand_unif` will generate different values: - >>> hl.eval(x) # doctest: +SKIP_OUTPUT_CHECK - 0.5562065047992025 + >>> hl.eval(x) + 0.9828239225846387 - >>> hl.eval(x) # doctest: +SKIP_OUTPUT_CHECK - 0.5562065047992025 + >>> hl.eval(x) + 0.9828239225846387 - >>> hl.eval(hl.rand_unif(0, 1)) # doctest: +SKIP_OUTPUT_CHECK - 0.4678132874101748 + >>> hl.eval(hl.rand_unif(0, 1)) + 0.49094525115847415 - >>> hl.eval(hl.rand_unif(0, 1)) # doctest: +SKIP_OUTPUT_CHECK - 0.9097632224065403 + >>> hl.eval(hl.rand_unif(0, 1)) + 0.3972543766997359 - >>> hl.eval(hl.array([x, x, x])) # doctest: +SKIP_OUTPUT_CHECK - [0.5562065047992025, 0.5562065047992025, 0.5562065047992025] + >>> hl.eval(hl.array([x, x, x])) + [0.9828239225846387, 0.9828239225846387, 0.9828239225846387] If the three values in the last expression should be distinct, three separate calls to :func:`.rand_unif` should be made: @@ -39,26 +41,27 @@ calls to :func:`.rand_unif` should be made: >>> a = hl.rand_unif(0, 1) >>> b = hl.rand_unif(0, 1) >>> c = hl.rand_unif(0, 1) - >>> hl.eval(hl.array([a, b, c])) # doctest: +SKIP_OUTPUT_CHECK - [0.8846327207915881, 0.14415148553468504, 0.8202677741734825] + >>> hl.eval(hl.array([a, b, c])) + [0.992090957001768, 0.9564448098124774, 0.3905029525642664] Within the rows of a :class:`.Table`, the same expression will yield a consistent value within each row, but different (random) values across rows: >>> table = hl.utils.range_table(5, 1) >>> table = table.annotate(x1=x, x2=x, rand=hl.rand_unif(0, 1)) - >>> table.show() # doctest: +SKIP_OUTPUT_CHECK - +-------+-------------+-------------+-------------+ - | idx | x1 | x2 | rand | - +-------+-------------+-------------+-------------+ - | int32 | float64 | float64 | float64 | - +-------+-------------+-------------+-------------+ - | 0 | 8.50369e-01 | 8.50369e-01 | 9.64129e-02 | - | 1 | 5.15437e-01 | 5.15437e-01 | 8.60843e-02 | - | 2 | 5.42493e-01 | 5.42493e-01 | 1.69816e-01 | - | 3 | 5.51289e-01 | 5.51289e-01 | 6.48706e-01 | - | 4 | 6.40977e-01 | 6.40977e-01 | 8.22508e-01 | - +-------+-------------+-------------+-------------+ + >>> table.show() + +-------+----------+----------+----------+ + | idx | x1 | x2 | rand | + +-------+----------+----------+----------+ + | int32 | float64 | float64 | float64 | + +-------+----------+----------+----------+ + | 0 | 4.68e-01 | 4.68e-01 | 6.36e-01 | + | 1 | 8.24e-01 | 8.24e-01 | 9.72e-01 | + | 2 | 7.33e-01 | 7.33e-01 | 1.43e-01 | + | 3 | 8.99e-01 | 8.99e-01 | 5.52e-01 | + | 4 | 4.03e-01 | 4.03e-01 | 3.50e-01 | + +-------+----------+----------+----------+ + The same is true of the rows, columns, and entries of a :class:`.MatrixTable`. @@ -69,51 +72,86 @@ All random functions can take a specified seed as an argument. This guarantees that multiple invocations of the same function within the same context will return the same result, e.g. - >>> hl.eval(hl.rand_unif(0, 1, seed=0)) # doctest: +SKIP_OUTPUT_CHECK - 0.5488135008937808 - - >>> hl.eval(hl.rand_unif(0, 1, seed=0)) # doctest: +SKIP_OUTPUT_CHECK - 0.5488135008937808 - -This does not guarantee the same behavior across different contexts; e.g., the -rows may have different values if the expression is applied to different tables: - - >>> table = hl.utils.range_table(5, 1).annotate(x=hl.rand_bool(0.5, seed=0)) - >>> table.x.collect() # doctest: +SKIP_OUTPUT_CHECK - [0.5488135008937808, - 0.7151893652121089, - 0.6027633824638369, - 0.5448831893094143, - 0.42365480398481625] - - >>> table = hl.utils.range_table(5, 1).annotate(x=hl.rand_bool(0.5, seed=0)) - >>> table.x.collect() # doctest: +SKIP_OUTPUT_CHECK - [0.5488135008937808, - 0.7151893652121089, - 0.6027633824638369, - 0.5448831893094143, - 0.42365480398481625] - - >>> table = hl.utils.range_table(5, 5).annotate(x=hl.rand_bool(0.5, seed=0)) - >>> table.x.collect() # doctest: +SKIP_OUTPUT_CHECK - [0.5488135008937808, - 0.9595974306263271, - 0.42205690070893265, - 0.828743805759555, - 0.6414977904324134] - -The seed can also be set globally using :func:`.set_global_seed`. This sets the -seed globally for all subsequent Hail operations, and a pipeline will be -guaranteed to have the same results if the global seed is set right beforehand: - - >>> hl.set_global_seed(0) - >>> hl.eval(hl.array([hl.rand_unif(0, 1), hl.rand_unif(0, 1)])) # doctest: +SKIP_OUTPUT_CHECK - [0.6830630912401323, 0.4035978197966855] - - >>> hl.set_global_seed(0) - >>> hl.eval(hl.array([hl.rand_unif(0, 1), hl.rand_unif(0, 1)])) # doctest: +SKIP_OUTPUT_CHECK - [0.6830630912401323, 0.4035978197966855] - + >>> hl.eval(hl.rand_unif(0, 1, seed=0)) + 0.2664972565962568 + + >>> hl.eval(hl.rand_unif(0, 1, seed=0)) + 0.2664972565962568 + + >>> table = hl.utils.range_table(5, 1).annotate(x=hl.rand_unif(0, 1, seed=0)) + >>> table.x.collect() + [0.5820244750020055, + 0.33150686392731943, + 0.20526631289173847, + 0.6964416913998893, + 0.6092952493383876] + + >>> table = hl.utils.range_table(5, 5).annotate(x=hl.rand_unif(0, 1, seed=0)) + >>> table.x.collect() + [0.5820244750020055, + 0.33150686392731943, + 0.20526631289173847, + 0.6964416913998893, + 0.6092952493383876] + +However, moving it to a sufficiently different context will produce different +results: + + >>> table = hl.utils.range_table(7, 1) + >>> table = table.filter(table.idx >= 2).annotate(x=hl.rand_unif(0, 1, seed=0)) + >>> table.x.collect() + [0.20526631289173847, + 0.6964416913998893, + 0.6092952493383876, + 0.6404026938964441, + 0.5550464170615771] + +In fact, in this case we are getting the tail of + + >>> table = hl.utils.range_table(7, 1).annotate(x=hl.rand_unif(0, 1, seed=0)) + >>> table.x.collect() + [0.5820244750020055, + 0.33150686392731943, + 0.20526631289173847, + 0.6964416913998893, + 0.6092952493383876, + 0.6404026938964441, + 0.5550464170615771] + +Reproducibility across sessions +=============================== + +The values of a random function are fully determined by three things: + +* The seed set on the function itself. If not specified, these are simply + generated sequentially. +* Some data uniquely identifying the current position within a larger context, + e.g. Table, MatrixTable, or array. For instance, in a :func:`.range_table`, + this data is simply the row id, as suggested by the previous examples. +* The global seed. This is fixed for the entire session, and can only be set + using the ``global_seed`` argument to :func:`.init`. + +To ensure reproducibility within a single hail session, it suffices to either +manually set the seed on every random function call, or to call +:func:`.reset_global_randomness` at the start of a pipeline, which resets the +counter used to generate seeds. + + >>> hl.reset_global_randomness() + >>> hl.eval(hl.array([hl.rand_unif(0, 1), hl.rand_unif(0, 1)])) + [0.9828239225846387, 0.49094525115847415] + + >>> hl.reset_global_randomness() + >>> hl.eval(hl.array([hl.rand_unif(0, 1), hl.rand_unif(0, 1)])) + [0.9828239225846387, 0.49094525115847415] + +To ensure reproducibility across sessions, one must in addition specify the +`global_seed` in :func:`.init`. If not specified, the global seed is chosen +randomly. All documentation examples were computed using ``global_seed=0``. + + >>> hl.stop() # doctest: +SKIP + >>> hl.init(global_seed=0) # doctest: +SKIP + >>> hl.eval(hl.array([hl.rand_unif(0, 1), hl.rand_unif(0, 1)])) # doctest: +SKIP + [0.9828239225846387, 0.49094525115847415] .. autosummary:: diff --git a/hail/python/hail/docs/guides/agg.rst b/hail/python/hail/docs/guides/agg.rst index 3185076037b..58ff362375f 100644 --- a/hail/python/hail/docs/guides/agg.rst +++ b/hail/python/hail/docs/guides/agg.rst @@ -98,7 +98,7 @@ One aggregation :**code**: >>> mt.aggregate_cols(hl.agg.fraction(mt.pheno.is_female)) - 0.48 + 0.44 :**dependencies**: :meth:`.MatrixTable.aggregate_cols`, :func:`.aggregators.fraction` @@ -114,7 +114,7 @@ Multiple aggregations >>> mt.aggregate_cols(hl.struct( ... fraction_female=hl.agg.fraction(mt.pheno.is_female), ... case_ratio=hl.agg.count_where(mt.is_case) / hl.agg.count())) - Struct(fraction_female=0.48, case_ratio=1.0) + Struct(fraction_female=0.44, case_ratio=1.0) :**dependencies**: :meth:`.MatrixTable.aggregate_cols`, :func:`.aggregators.fraction`, :func:`.aggregators.count_where`, :class:`.StructExpression` @@ -129,7 +129,7 @@ One aggregation :**code**: >>> mt.aggregate_rows(hl.agg.mean(mt.qual)) - 544323.8915384616 + 140054.73333333334 :**dependencies**: :meth:`.MatrixTable.aggregate_rows`, :func:`.aggregators.mean` @@ -148,7 +148,7 @@ Multiple aggregations >>> mt.aggregate_rows( ... hl.struct(n_high_quality=hl.agg.count_where(mt.qual > 40), ... mean_qual=hl.agg.mean(mt.qual))) - Struct(n_high_quality=13, mean_qual=544323.8915384616) + Struct(n_high_quality=9, mean_qual=140054.73333333334) :**dependencies**: :meth:`.MatrixTable.aggregate_rows`, :func:`.aggregators.count_where`, :func:`.aggregators.mean`, :class:`.StructExpression` @@ -167,7 +167,7 @@ Aggregate Entry Values Into A Local Value >>> mt.aggregate_entries( ... hl.struct(global_gq_mean=hl.agg.mean(mt.GQ), ... call_rate=hl.agg.fraction(hl.is_defined(mt.GT)))) - Struct(global_gq_mean=64.01841473178543, call_rate=0.9607692307692308) + Struct(global_gq_mean=69.60514541387025, call_rate=0.9933333333333333) :**dependencies**: :meth:`.MatrixTable.aggregate_entries`, :func:`.aggregators.mean`, :func:`.aggregators.fraction`, :class:`.StructExpression` diff --git a/hail/python/hail/docs/scans.rst b/hail/python/hail/docs/scans.rst index 517c87cee1d..804345b6cd0 100644 --- a/hail/python/hail/docs/scans.rst +++ b/hail/python/hail/docs/scans.rst @@ -45,18 +45,16 @@ along the genome: +---------------+------------+-----------+---------------+ | locus | array | int64 | int64 | +---------------+------------+-----------+---------------+ - | 20:12990057 | ["T","A"] | 43 | 0 | - | 20:13029862 | ["C","T"] | 99 | 43 | - | 20:13074235 | ["G","A"] | 99 | 142 | - | 20:13140720 | ["G","A"] | 5 | 241 | - | 20:13695498 | ["G","A"] | 25 | 246 | - | 20:13714384 | ["A","C"] | 1 | 271 | - | 20:13765944 | ["C","G"] | 2 | 272 | - | 20:13765954 | ["C","T"] | 2 | 274 | - | 20:13845987 | ["C","T"] | 100 | 276 | - | 20:16223957 | ["T","C"] | 31 | 376 | + | 20:10579373 | ["C","T"] | 1 | 0 | + | 20:10579398 | ["C","T"] | 1 | 1 | + | 20:10627772 | ["C","T"] | 2 | 2 | + | 20:10633237 | ["G","A"] | 69 | 4 | + | 20:10636995 | ["C","T"] | 2 | 73 | + | 20:10639222 | ["G","A"] | 22 | 75 | + | 20:13763601 | ["A","G"] | 2 | 97 | + | 20:16223922 | ["T","C"] | 66 | 99 | + | 20:17479617 | ["G","A"] | 9 | 165 | +---------------+------------+-----------+---------------+ - showing top 10 rows Scans over column fields can be done in a similar manner. diff --git a/hail/python/hail/experimental/full_outer_join_mt.py b/hail/python/hail/experimental/full_outer_join_mt.py index 44a9e839bad..b5daea0c149 100644 --- a/hail/python/hail/experimental/full_outer_join_mt.py +++ b/hail/python/hail/experimental/full_outer_join_mt.py @@ -20,7 +20,7 @@ def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.Matrix genotypes for loci 1:1 and 1:2 because those loci are not present in `mt2` and these samples are not present in `mt1` - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> mt1 = hl.balding_nichols_model(1, 2, 3) >>> mt2 = hl.balding_nichols_model(1, 2, 3) >>> mt2 = mt2.key_rows_by(locus=hl.locus(mt2.locus.contig, @@ -34,19 +34,19 @@ def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.Matrix | locus | array | call | call | +---------------+------------+------+------+ | 1:1 | ["A","C"] | 0/1 | 0/1 | - | 1:2 | ["A","C"] | 1/1 | 1/1 | + | 1:2 | ["A","C"] | 0/0 | 1/1 | | 1:3 | ["A","C"] | 0/0 | 0/0 | +---------------+------------+------+------+ - >>> mt2.show() # doctest: +SKIP_OUTPUT_CHECK + >>> mt2.show() +---------------+------------+------+------+ - | locus | alleles | 0.GT | 1.GT | + | locus | alleles | 2.GT | 3.GT | +---------------+------------+------+------+ | locus | array | call | call | +---------------+------------+------+------+ - | 1:3 | ["A","C"] | 0/1 | 1/1 | - | 1:4 | ["A","C"] | 0/1 | 0/1 | - | 1:5 | ["A","C"] | 1/1 | 0/0 | + | 1:3 | ["A","C"] | 0/0 | 0/1 | + | 1:4 | ["A","C"] | 1/1 | 0/1 | + | 1:5 | ["A","C"] | 0/0 | 0/0 | +---------------+------------+------+------+ >>> mt3 = hl.experimental.full_outer_join_mt(mt1, mt2) @@ -58,10 +58,10 @@ def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.Matrix | locus | array | call | call | call | call | +---------------+------------+------+------+------+------+ | 1:1 | ["A","C"] | 0/1 | 0/1 | NA | NA | - | 1:2 | ["A","C"] | 1/1 | 1/1 | NA | NA | - | 1:3 | ["A","C"] | 0/0 | 0/0 | 0/1 | 1/1 | - | 1:4 | ["A","C"] | NA | NA | 0/1 | 0/1 | - | 1:5 | ["A","C"] | NA | NA | 1/1 | 0/0 | + | 1:2 | ["A","C"] | 0/0 | 1/1 | NA | NA | + | 1:3 | ["A","C"] | 0/0 | 0/0 | 0/0 | 0/1 | + | 1:4 | ["A","C"] | NA | NA | 1/1 | 0/1 | + | 1:5 | ["A","C"] | NA | NA | 0/0 | 0/0 | +---------------+------------+------+------+------+------+ diff --git a/hail/python/hail/experimental/ldscsim.py b/hail/python/hail/experimental/ldscsim.py index 486e25b13d4..1b3705c9544 100644 --- a/hail/python/hail/experimental/ldscsim.py +++ b/hail/python/hail/experimental/ldscsim.py @@ -264,7 +264,6 @@ def multitrait_inf(mt, h2=None, rg=None, cov_matrix=None, seed=None): rg = rg.tolist() if type(rg) is np.ndarray else ([rg] if type(rg) is not list else rg) assert (all(x >= 0 and x <= 1 for x in h2)), 'h2 values must be between 0 and 1' assert h2 is not [None] or cov_matrix is not None, 'h2 and cov_matrix cannot both be None' - seed = seed if seed is not None else int(str(Env.next_seed())[:8]) M = mt.count_rows() if cov_matrix is not None: n_phens = cov_matrix.shape[0] @@ -336,7 +335,6 @@ def multitrait_ss(mt, h2, pi, rg=0, seed=None): covariance matrix was not positive semi-definite. """ assert sum(pi) <= 1, "probabilities of being causal must sum to be less than 1" - seed = seed if seed is not None else int(str(Env.next_seed())[:8]) ptt, ptf, pft, pff = pi[0], pi[1], pi[2], 1 - sum(pi) cov_matrix = np.asarray([[1 / (ptt + ptf), rg / ptt], [rg / ptt, 1 / (ptt + pft)]]) M = mt.count_rows() diff --git a/hail/python/hail/experimental/table_ndarray_utils.py b/hail/python/hail/experimental/table_ndarray_utils.py index 7d9fcef5cef..be5dbecf9f5 100644 --- a/hail/python/hail/experimental/table_ndarray_utils.py +++ b/hail/python/hail/experimental/table_ndarray_utils.py @@ -39,6 +39,8 @@ def get_even_partitioning(ht, partition_size, total_num_rows): grouped = new_part_ht._group_within_partitions("groups", block_size) A = grouped.select(ndarray=hl.nd.array(grouped.groups.map(lambda group: group.xs))) + temp_file_name2 = hl.utils.new_temp_file("mt_to_table_of_ndarray", "A") + A = A.checkpoint(temp_file_name2) if return_checkpointed_table_also: return A, ht diff --git a/hail/python/hail/expr/aggregators/aggregators.py b/hail/python/hail/expr/aggregators/aggregators.py index 06c98b73f88..45fe86fda35 100644 --- a/hail/python/hail/expr/aggregators/aggregators.py +++ b/hail/python/hail/expr/aggregators/aggregators.py @@ -1085,7 +1085,7 @@ def explode(f, array_agg_expr) -> Expression: Compute the set of all observed elements in the `filters` field (``Set[String]``): >>> dataset.aggregate_rows(hl.agg.explode(lambda elt: hl.agg.collect_as_set(elt), dataset.filters)) - {'VQSRTrancheINDEL97.00to99.00'} + set() Notes ----- @@ -1154,16 +1154,16 @@ def inbreeding(expr, prior) -> StructExpression: +------------------+-----------+-------------+------------------+------------------+ | str | float64 | int64 | float64 | int64 | +------------------+-----------+-------------+------------------+------------------+ - | "C1046::HG02024" | 3.88e-01 | 12 | 1.04e+01 | 11 | - | "C1046::HG02025" | 3.99e-01 | 13 | 1.13e+01 | 12 | - | "C1046::HG02026" | 3.88e-01 | 12 | 1.04e+01 | 11 | - | "C1047::HG00731" | -1.31e+00 | 12 | 1.07e+01 | 9 | - | "C1047::HG00732" | 1.00e+00 | 12 | 1.04e+01 | 12 | - | "C1047::HG00733" | -8.04e-01 | 13 | 1.13e+01 | 10 | - | "C1048::HG02024" | 3.88e-01 | 12 | 1.04e+01 | 11 | - | "C1048::HG02025" | 3.99e-01 | 13 | 1.13e+01 | 12 | - | "C1048::HG02026" | 1.00e+00 | 12 | 1.04e+01 | 12 | - | "C1049::HG00731" | -1.41e+00 | 13 | 1.13e+01 | 9 | + | "C1046::HG02024" | 2.79e-01 | 9 | 7.61e+00 | 8 | + | "C1046::HG02025" | -4.41e-01 | 9 | 7.61e+00 | 7 | + | "C1046::HG02026" | -4.41e-01 | 9 | 7.61e+00 | 7 | + | "C1047::HG00731" | 2.79e-01 | 9 | 7.61e+00 | 8 | + | "C1047::HG00732" | 2.79e-01 | 9 | 7.61e+00 | 8 | + | "C1047::HG00733" | 2.79e-01 | 9 | 7.61e+00 | 8 | + | "C1048::HG02024" | -4.41e-01 | 9 | 7.61e+00 | 7 | + | "C1048::HG02025" | -4.41e-01 | 9 | 7.61e+00 | 7 | + | "C1048::HG02026" | -4.41e-01 | 9 | 7.61e+00 | 7 | + | "C1049::HG00731" | 2.79e-01 | 9 | 7.61e+00 | 8 | +------------------+-----------+-------------+------------------+------------------+ showing top 10 rows @@ -1232,18 +1232,16 @@ def call_stats(call, alleles) -> StructExpression: +---------------+--------------+---------------------+-------------+---------------------------+ | locus | array | array | int32 | array | +---------------+--------------+---------------------+-------------+---------------------------+ - | 20:12990057 | [148,52] | [7.40e-01,2.60e-01] | 200 | [57,9] | - | 20:13029862 | [0,198] | [0.00e+00,1.00e+00] | 198 | [0,99] | - | 20:13074235 | [13,187] | [6.50e-02,9.35e-01] | 200 | [1,88] | - | 20:13140720 | [194,6] | [9.70e-01,3.00e-02] | 200 | [95,1] | - | 20:13695498 | [175,25] | [8.75e-01,1.25e-01] | 200 | [75,0] | - | 20:13714384 | [199,1] | [9.95e-01,5.00e-03] | 200 | [99,0] | - | 20:13765944 | [132,2] | [9.85e-01,1.49e-02] | 134 | [65,0] | - | 20:13765954 | [180,2] | [9.89e-01,1.10e-02] | 182 | [89,0] | - | 20:13845987 | [2,198] | [1.00e-02,9.90e-01] | 200 | [0,98] | - | 20:16223957 | [145,45] | [7.63e-01,2.37e-01] | 190 | [64,14] | + | 20:10579373 | [199,1] | [9.95e-01,5.00e-03] | 200 | [99,0] | + | 20:10579398 | [198,2] | [9.90e-01,1.00e-02] | 200 | [99,1] | + | 20:10627772 | [198,2] | [9.90e-01,1.00e-02] | 200 | [98,0] | + | 20:10633237 | [108,92] | [5.40e-01,4.60e-01] | 200 | [31,23] | + | 20:10636995 | [198,2] | [9.90e-01,1.00e-02] | 200 | [98,0] | + | 20:10639222 | [175,25] | [8.75e-01,1.25e-01] | 200 | [78,3] | + | 20:13763601 | [198,2] | [9.90e-01,1.00e-02] | 200 | [98,0] | + | 20:16223922 | [87,101] | [4.63e-01,5.37e-01] | 188 | [28,35] | + | 20:17479617 | [191,9] | [9.55e-01,4.50e-02] | 200 | [91,0] | +---------------+--------------+---------------------+-------------+---------------------------+ - showing top 10 rows Notes diff --git a/hail/python/hail/expr/expressions/base_expression.py b/hail/python/hail/expr/expressions/base_expression.py index a82974256ce..ac9a5bf32ac 100644 --- a/hail/python/hail/expr/expressions/base_expression.py +++ b/hail/python/hail/expr/expressions/base_expression.py @@ -963,31 +963,29 @@ def export(self, path, delimiter='\t', missing='NA', header=True): ... for line in f: ... print(line, end='') locus alleles 0 1 2 3 - 1:1 ["A","C"] 0/1 0/1 0/0 0/0 - 1:2 ["A","C"] 1/1 0/1 1/1 1/1 - 1:3 ["A","C"] 1/1 0/1 0/1 0/0 - 1:4 ["A","C"] 1/1 0/1 1/1 1/1 - + 1:1 ["A","C"] 1/1 1/1 0/1 0/1 + 1:2 ["A","C"] 1/1 1/1 0/0 1/1 + 1:3 ["A","C"] 0/0 0/0 0/1 0/0 + 1:4 ["A","C"] 1/1 0/1 1/1 0/1 >>> small_mt.GT.export('output/gt-no-header.tsv', header=False) >>> with open('output/gt-no-header.tsv', 'r') as f: ... for line in f: ... print(line, end='') - 1:1 ["A","C"] 0/1 0/1 0/0 0/0 - 1:2 ["A","C"] 1/1 0/1 1/1 1/1 - 1:3 ["A","C"] 1/1 0/1 0/1 0/0 - 1:4 ["A","C"] 1/1 0/1 1/1 1/1 - + 1:1 ["A","C"] 1/1 1/1 0/1 0/1 + 1:2 ["A","C"] 1/1 1/1 0/0 1/1 + 1:3 ["A","C"] 0/0 0/0 0/1 0/0 + 1:4 ["A","C"] 1/1 0/1 1/1 0/1 >>> small_mt.pop.export('output/pops.tsv') >>> with open('output/pops.tsv', 'r') as f: ... for line in f: ... print(line, end='') sample_idx pop - 0 2 - 1 2 - 2 0 - 3 2 + 0 0 + 1 0 + 2 2 + 3 0 >>> small_mt.ancestral_af.export('output/ancestral_af.tsv') @@ -995,13 +993,12 @@ def export(self, path, delimiter='\t', missing='NA', header=True): ... for line in f: ... print(line, end='') locus alleles ancestral_af - 1:1 ["A","C"] 5.3905e-01 - 1:2 ["A","C"] 8.6768e-01 - 1:3 ["A","C"] 4.3765e-01 - 1:4 ["A","C"] 7.6300e-01 + 1:1 ["A","C"] 5.6562e-01 + 1:2 ["A","C"] 3.6521e-01 + 1:3 ["A","C"] 2.6421e-01 + 1:4 ["A","C"] 6.5715e-01 - >>> mt = small_mt >>> small_mt.bn.export('output/bn.tsv') >>> with open('output/bn.tsv', 'r') as f: ... for line in f: @@ -1025,10 +1022,10 @@ def export(self, path, delimiter='\t', missing='NA', header=True): ... for line in f: ... print(line, end='') locus alleles {"s":0,"family":"fam1"} {"s":1,"family":"fam1"} {"s":2,"family":"fam1"} {"s":3,"family":"fam1"} - 1:1 ["A","C"] 0/1 0/1 0/0 0/0 - 1:2 ["A","C"] 1/1 0/1 1/1 1/1 - 1:3 ["A","C"] 1/1 0/1 0/1 0/0 - 1:4 ["A","C"] 1/1 0/1 1/1 1/1 + 1:1 ["A","C"] 1/1 1/1 0/1 0/1 + 1:2 ["A","C"] 1/1 1/1 0/0 1/1 + 1:3 ["A","C"] 0/0 0/0 0/1 0/0 + 1:4 ["A","C"] 1/1 0/1 1/1 0/1 diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index 7427e83b234..ef1efe8087f 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -43,10 +43,17 @@ def _func(name, ret_type, *args, type_args=()): def _seeded_func(name, ret_type, seed, *args): - seed = seed if seed is not None else Env.next_seed() + if seed is None: + static_rng_uid = Env.next_static_rng_uid() + else: + if Env._hc is None or not Env._hc._user_specified_rng_nonce: + warning('To ensure reproducible randomness across Hail sessions, ' + 'you must set the "global_seed" parameter in hl.init(), in ' + 'addition to the local seed in each random function.') + static_rng_uid = -seed - 1 indices, aggregations = unify_all(*args) - rng_state = ir.RNGSplit(ir.Ref('__rng_state', trngstate), ir.MakeTuple([])) - return construct_expr(ir.ApplySeeded(name, seed, rng_state, ret_type, *(a._ir for a in args)), ret_type, indices, aggregations) + rng_state = ir.Ref('__rng_state', trngstate) + return construct_expr(ir.ApplySeeded(name, static_rng_uid, rng_state, ret_type, *(a._ir for a in args)), ret_type, indices, aggregations) def ndarray_broadcasting(func): @@ -2441,7 +2448,7 @@ def rand_bool(p, seed=None) -> BooleanExpression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_bool(0.5)) False @@ -2462,20 +2469,20 @@ def rand_bool(p, seed=None) -> BooleanExpression: return _seeded_func("rand_bool", tbool, seed, p) -@typecheck(mean=expr_float64, sd=expr_float64, seed=nullable(int)) -def rand_norm(mean=0, sd=1, seed=None) -> Float64Expression: +@typecheck(mean=expr_float64, sd=expr_float64, seed=nullable(int), size=nullable(tupleof(expr_int64))) +def rand_norm(mean=0, sd=1, seed=None, size=None) -> Float64Expression: """Samples from a normal distribution with mean `mean` and standard deviation `sd`. Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_norm()) - 0.30971254606692267 + 0.347110923255205 >>> hl.eval(hl.rand_norm()) - -1.6120679347033475 + -0.9281375348070483 Parameters ---------- @@ -2485,12 +2492,17 @@ def rand_norm(mean=0, sd=1, seed=None) -> Float64Expression: Standard deviation of normal distribution. seed : :obj:`int`, optional Random seed. + size : :obj:`int` or :obj:`tuple` of :obj:`int`, optional Returns ------- :class:`.Float64Expression` """ - return _seeded_func("rand_norm", tfloat64, seed, mean, sd) + if size is None: + return _seeded_func("rand_norm", tfloat64, seed, mean, sd) + else: + (nrows, ncols) = size + return _seeded_func("rand_norm_nd", tndarray(tfloat64, 2), seed, nrows, ncols, mean, sd) @typecheck(mean=nullable(expr_array(expr_float64)), cov=nullable(expr_array(expr_float64)), seed=nullable(int)) @@ -2500,12 +2512,12 @@ def rand_norm2d(mean=None, cov=None, seed=None) -> ArrayNumericExpression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_norm2d()) - [0.30971254606692267, -1.266553783097155] + [-1.3909495945443346, 1.2805588680053859] >>> hl.eval(hl.rand_norm2d()) - [-1.6120679347033475, 1.6121791827078364] + [0.289520302334123, -1.1108917435930954] Notes ----- @@ -2562,12 +2574,12 @@ def rand_pois(lamb, seed=None) -> Float64Expression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_pois(1)) - 1.0 + 4.0 >>> hl.eval(hl.rand_pois(1)) - 1.0 + 4.0 Parameters ---------- @@ -2583,23 +2595,23 @@ def rand_pois(lamb, seed=None) -> Float64Expression: return _seeded_func("rand_pois", tfloat64, seed, lamb) -@typecheck(lower=expr_float64, upper=expr_float64, seed=nullable(int)) -def rand_unif(lower=0.0, upper=1.0, seed=None) -> Float64Expression: +@typecheck(lower=expr_float64, upper=expr_float64, seed=nullable(int), size=nullable(tupleof(expr_int64))) +def rand_unif(lower=0.0, upper=1.0, seed=None, size=None) -> Float64Expression: """Samples from a uniform distribution within the interval [`lower`, `upper`]. Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_unif()) - 0.6830630912401323 + 0.9828239225846387 >>> hl.eval(hl.rand_unif(0, 1)) - 0.4035978197966855 + 0.49094525115847415 >>> hl.eval(hl.rand_unif(0, 1)) - 0.26020045338162423 + 0.3972543766997359 Parameters ---------- @@ -2609,12 +2621,17 @@ def rand_unif(lower=0.0, upper=1.0, seed=None) -> Float64Expression: Right boundary of range. Defaults to 1.0. seed : :obj:`int`, optional Random seed. + size : :obj:`int` or :obj:`tuple` of :obj:`int`, optional Returns ------- :class:`.Float64Expression` """ - return _seeded_func("rand_unif", tfloat64, seed, lower, upper) + if size is None: + return _seeded_func("rand_unif", tfloat64, seed, lower, upper) + else: + (nrows, ncols) = size + return _seeded_func("rand_unif_nd", tndarray(tfloat64, 2), seed, nrows, ncols, lower, upper) @typecheck(a=expr_int32, b=nullable(expr_int32), seed=nullable(int)) @@ -2627,12 +2644,12 @@ def rand_int32(a, b=None, *, seed=None) -> Int32Expression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_int32(10)) - 0 + 9 >>> hl.eval(hl.rand_int32(10, 15)) - 11 + 14 >>> hl.eval(hl.rand_int32(10, 15)) 12 @@ -2656,25 +2673,27 @@ def rand_int32(a, b=None, *, seed=None) -> Int32Expression: return _seeded_func("rand_int32", tint32, seed, b - a) + a -@typecheck(a=expr_int64, b=nullable(expr_int64), seed=nullable(int)) -def rand_int64(a, b=None, *, seed=None) -> Int64Expression: +@typecheck(a=nullable(expr_int64), b=nullable(expr_int64), seed=nullable(int)) +def rand_int64(a=None, b=None, *, seed=None) -> Int64Expression: """Samples from a uniform distribution of 64-bit integers. - If b is `None`, samples from the uniform distribution over [0, a). Otherwise, sample from the - uniform distribution over [a, b). + If a and b are both specified, samples from the uniform distribution over [a, b). + If b is `None`, samples from the uniform distribution over [0, a). + If both a and b are `None` samples from the uniform distribution over all + 64-bit integers. Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_int64(10)) - 2 + 9 >>> hl.eval(hl.rand_int64(1 << 33, 1 << 35)) - 13313179445 + 33089740109 >>> hl.eval(hl.rand_int64(1 << 33, 1 << 35)) - 18981019040 + 18195458570 Parameters ---------- @@ -2689,6 +2708,8 @@ def rand_int64(a, b=None, *, seed=None) -> Int64Expression: ------- :class:`.Int64Expression` """ + if a is None: + return _seeded_func("rand_int64", tint64, seed) if b is None: return _seeded_func("rand_int64", tint64, seed, a) return _seeded_func("rand_int64", tint64, seed, b - a) + a @@ -2715,12 +2736,12 @@ def rand_beta(a, b, lower=None, upper=None, seed=None) -> Float64Expression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_beta(0.5, 0.5)) - 0.3483677318466065 + 0.30607924177641355 >>> hl.eval(hl.rand_beta(2, 5)) - 0.23894608018057753 + 0.1103872607301062 Parameters ---------- @@ -2758,12 +2779,12 @@ def rand_gamma(shape, scale, seed=None) -> Float64Expression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_gamma(1, 1)) - 0.8934929450909811 + 3.115449479063202 >>> hl.eval(hl.rand_gamma(1, 1)) - 0.3423233699402248 + 3.077698059931638 Parameters ---------- @@ -2798,12 +2819,12 @@ def rand_cat(prob, seed=None) -> Int32Expression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_cat([0, 1.7, 2])) 2 >>> hl.eval(hl.rand_cat([0, 1.7, 2])) - 1 + 2 Parameters ---------- @@ -2827,12 +2848,12 @@ def rand_dirichlet(a, seed=None) -> ArrayExpression: Examples -------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.rand_dirichlet([1, 1, 1])) - [0.31600799564679466, 0.22921566396520351, 0.45477634038800185] + [0.6987619676833735, 0.287566556865261, 0.013671475451365567] >>> hl.eval(hl.rand_dirichlet([1, 1, 1])) - [0.28935842257116556, 0.40020478428981887, 0.31043679313901557] + [0.16299928555608242, 0.04393664153526524, 0.7930640729086523] Parameters ---------- @@ -6284,9 +6305,9 @@ def shuffle(a, seed: builtins.int = None) -> ArrayExpression: Example ------- - >>> hl.set_global_seed(0) + >>> hl.reset_global_randomness() >>> hl.eval(hl.shuffle(hl.range(5))) - [3, 2, 0, 4, 1] + [4, 0, 2, 1, 3] Parameters ---------- @@ -6299,7 +6320,7 @@ def shuffle(a, seed: builtins.int = None) -> ArrayExpression: ------- :class:`.ArrayExpression` """ - return sorted(a, key=lambda _: hl.rand_unif(0.0, 1.0, seed=seed)) + return sorted(a, key=lambda _: hl.rand_unif(0.0, 1.0)) @typecheck(path=builtins.str, point_or_interval=expr_any) diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index 0d0bb571292..45e1fd6c525 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -18,8 +18,8 @@ MakeStruct, SelectFields, InsertFields, GetField, MakeTuple, \ GetTupleElement, Die, ConsoleLog, Apply, ApplySeeded, RNGStateLiteral, RNGSplit,\ TableCount, TableGetGlobals, TableCollect, TableAggregate, MatrixCount,\ - MatrixAggregate, TableWrite, udf, subst, clear_session_functions, get_static_split_uid, \ - ReadPartition, PartitionNativeIntervalReader + MatrixAggregate, TableWrite, udf, subst, clear_session_functions, ReadPartition,\ + PartitionNativeIntervalReader from .register_functions import register_functions from .register_aggregators import register_aggregators from .table_ir import MatrixRowsTable, TableJoin, TableLeftJoinRightDistinct, \ @@ -216,7 +216,6 @@ 'udf', 'subst', 'clear_session_functions', - 'get_static_split_uid', 'MatrixWrite', 'MatrixMultiWrite', 'BlockMatrixWrite', diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index e3b6a637f5e..5a0294d47bc 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,6 +1,6 @@ import abc -from hail.expr.types import tstream +from hail.expr.types import tstream, tstruct from hail.utils.java import Env from .renderer import Renderer, PlainRenderer, Renderable @@ -244,6 +244,8 @@ def __init__(self, *children): self._free_vars = None self._free_agg_vars = None self._free_scan_vars = None + self.has_uids = False + self.needs_randomness_handling = False @property def aggregations(self): @@ -301,6 +303,7 @@ def renderable_new_block(self, i: int) -> bool: def compute_type(self, env, agg_env, deep_typecheck): if deep_typecheck or self._type is None: computed = self._compute_type(env, agg_env, deep_typecheck) + assert(computed is not None) if self._type is not None: assert self._type == computed self._type = computed @@ -339,10 +342,14 @@ def handle_randomness(self, create_uids): The uid may be an int64, or arbitrary tuple of int64s. The only requirement is that all stream elements contain distinct uid values. """ - assert(isinstance(self.typ, tstream)) - if not create_uids and not self.uses_randomness: + assert(self.is_stream) + if (create_uids == self.has_uids) and not self.needs_randomness_handling: return self - return self._handle_randomness(create_uids) + new = self._handle_randomness(create_uids) + assert(isinstance(self.typ.element_type, tstruct) == isinstance(new.typ.element_type, tstruct)) + new.has_uids = create_uids + new.needs_randomness_handling = False + return new @property def free_vars(self): diff --git a/hail/python/hail/ir/blockmatrix_ir.py b/hail/python/hail/ir/blockmatrix_ir.py index 5e7003e65dd..3897a88bffd 100644 --- a/hail/python/hail/ir/blockmatrix_ir.py +++ b/hail/python/hail/ir/blockmatrix_ir.py @@ -32,6 +32,7 @@ def _compute_type(self, deep_typecheck): class BlockMatrixMap(BlockMatrixIR): @typecheck_method(child=BlockMatrixIR, name=str, f=IR, needs_dense=bool) def __init__(self, child, name, f, needs_dense): + assert(not f.uses_randomness) super().__init__(child, f) self.child = child self.name = name @@ -64,6 +65,7 @@ def binds(self, i): class BlockMatrixMap2(BlockMatrixIR): @typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR, left_name=str, right_name=str, f=IR, sparsity_strategy=str) def __init__(self, left, right, left_name, right_name, f, sparsity_strategy): + assert(not f.uses_randomness) super().__init__(left, right, f) self.left = left self.right = right @@ -299,6 +301,7 @@ def __repr__(self): class BlockMatrixSparsify(BlockMatrixIR): @typecheck_method(child=BlockMatrixIR, value=IR, sparsifier=BlockMatrixSparsifier) def __init__(self, child, value, sparsifier): + assert(not value.uses_randomness) super().__init__(value, child) self.child = child self.value = value @@ -345,6 +348,9 @@ class ValueToBlockMatrix(BlockMatrixIR): shape=sequenceof(int), block_size=int) def __init__(self, child, shape, block_size): + from .ir import Let, RNGStateLiteral + if child.uses_randomness: + child = Let('__rng_state', RNGStateLiteral(), child) super().__init__(child) self.child = child self.shape = shape @@ -372,25 +378,25 @@ def _compute_type(self, deep_typecheck): class BlockMatrixRandom(BlockMatrixIR): - @typecheck_method(seed=int, + @typecheck_method(static_rng_uid=int, gaussian=bool, shape=sequenceof(int), block_size=int) - def __init__(self, seed, gaussian, shape, block_size): + def __init__(self, static_rng_uid, gaussian, shape, block_size): super().__init__() - self.seed = seed + self.static_rng_uid = static_rng_uid self.gaussian = gaussian self.shape = shape self.block_size = block_size def head_str(self): - return '{} {} {} {}'.format(self.seed, + return '{} {} {} {}'.format(self.static_rng_uid, self.gaussian, _serialize_list(self.shape), self.block_size) def _eq(self, other): - return self.seed == other.seed and \ + return self.static_rng_uid == other.static_rng_uid and \ self.gaussian == other.gaussian and \ self.shape == other.shape and \ self.block_size == other.block_size diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index a3acc048e4b..bda5abe241d 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -11,14 +11,14 @@ from hail.ir.blockmatrix_writer import BlockMatrixWriter, BlockMatrixMultiWriter from hail.typecheck import typecheck, typecheck_method, sequenceof, numeric, \ sized_tupleof, nullable, tupleof, anytype, func_spec -from hail.utils.java import Env, HailUserError, warning +from hail.utils.java import Env, HailUserError from hail.utils.jsonx import dump_json from hail.utils.misc import escape_str, parsable_strings, escape_id from .base_ir import BaseIR, IR, TableIR, MatrixIR, BlockMatrixIR, _env_bind from .matrix_writer import MatrixWriter, MatrixNativeMultiWriter from .renderer import Renderer, Renderable, ParensRenderer from .table_writer import TableWriter -from .utils import default_row_uid, default_col_uid, rng_key, unpack_row_uid, unpack_col_uid +from .utils import default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid class I32(IR): @@ -188,12 +188,12 @@ def __init__(self, typ): self._typ = typ def _handle_randomness(self, create_uids): - if create_uids: - if isinstance(self.typ.element_type, tstruct): - new_elt_typ = self.typ.element_type._insert_field(uid_field_name, tint64) - else: - new_elt_typ = ttuple(tint64, self.typ.element_type) - return NA(tstream(new_elt_typ)) + assert create_uids + if isinstance(self.typ.element_type, tstruct): + new_elt_typ = self.typ.element_type._insert_field(uid_field_name, tint64) + else: + new_elt_typ = ttuple(tint64, self.typ.element_type) + return NA(tstream(new_elt_typ)) @property def typ(self): @@ -234,6 +234,7 @@ def __init__(self, cond, cnsq, altr): self.cond = cond self.cnsq = cnsq self.altr = altr + self.needs_randomness_handling = cnsq.needs_randomness_handling or altr.needs_randomness_handling def _handle_randomness(self, create_uids): return If(self.cond, @@ -283,6 +284,10 @@ def __init__(self, name, value, body): self.name = name self.value = value self.body = body + self.needs_randomness_handling = body.needs_randomness_handling + + def _handle_randomness(self, create_uids): + return Let(self.name, self.value, self.body.handle_randomness(create_uids)) @typecheck_method(value=IR, body=IR) def copy(self, value, body): @@ -370,21 +375,26 @@ def renderable_uses_scan_context(self, i: int) -> bool: class Ref(IR): - @typecheck_method(name=str, type=nullable(HailType)) - def __init__(self, name, type=None): + @typecheck_method(name=str, type=nullable(HailType), has_uids=bool) + def __init__(self, name, type=None, has_uids=False): super().__init__() self.name = name self._free_vars = {name} self._typ = type + self.has_uids = has_uids def _handle_randomness(self, create_uids): - if not create_uids: - return self - elt = Env.get_uid - uid = Env.get_uid - return StreamZip([self, StreamIota(I32(0), I32(1))], - [elt, uid], - pack_uid(Cast(Ref(uid, tint32), tint64), Ref(elt, tint32))) + assert create_uids != self.has_uids + if create_uids: + elt = Env.get_uid() + uid = Env.get_uid() + return StreamZip([self, StreamIota(I32(0), I32(1))], + [elt, uid], + pack_uid(Cast(Ref(uid, tint32), tint64), Ref(elt, self.typ.element_type)), + 'TakeMinLength') + else: + tuple, uid, elt = unpack_uid(self.typ) + return StreamMap(self, tuple, elt) def copy(self): return Ref(self.name, self._type) @@ -407,7 +417,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class TopLevelReference(Ref): @typecheck_method(name=str, type=nullable(HailType)) - def __init__(self, name, type=None): + def __init__(self, name, type): super().__init__(name, type) @property @@ -690,9 +700,8 @@ def __init__(self, start, step, requires_memory_management_per_element=False): self.requires_memory_management_per_element = requires_memory_management_per_element def _handle_randomness(self, create_uids): - if not create_uids: - return self - elt = Env.get_uid + assert create_uids + elt = Env.get_uid() return StreamMap(self, elt, MakeTuple([Cast(Ref(elt, tint32), tint64), Ref(elt, tint32)])) @typecheck_method(start=IR, step=IR) @@ -725,8 +734,7 @@ def __init__(self, start, stop, step, requires_memory_management_per_element=Fal self.save_error_info() def _handle_randomness(self, create_uids): - if not create_uids: - return self + assert create_uids elt = Env.get_uid() return StreamMap(self, elt, MakeTuple([Cast(Ref(elt, tint32), tint64), Ref(elt, tint32)])) @@ -750,9 +758,11 @@ def __init__(self, stream, group_size): super().__init__(stream, group_size) self.stream = stream self.group_size = group_size + self.needs_randomness_handling = stream.needs_randomness_handling def _handle_randomness(self, create_uids): assert(not create_uids) + assert(self.stream.needs_randomness_handling) self.stream.handle_randomness(False) @typecheck_method(stream=IR, group_size=IR) @@ -832,6 +842,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class NDArrayMap(IR): @typecheck_method(nd=IR, name=str, body=IR) def __init__(self, nd, name, body): + assert(not body.uses_randomness) super().__init__(nd, body) self.nd = nd self.name = name @@ -1152,6 +1163,7 @@ def is_effectful() -> bool: class ArraySort(IR): @typecheck_method(a=IR, l_name=str, r_name=str, compare=IR) def __init__(self, a, l_name, r_name, compare): + a = a.handle_randomness(False) super().__init__(a, compare) self.a = a self.l_name = l_name @@ -1228,8 +1240,7 @@ def toArray(s): class ToArray(IR): @typecheck_method(a=IR) def __init__(self, a): - if a.uses_randomness: - a = a.handle_randomness(False) + a = a.handle_randomness(False) super().__init__(a) self.a = a @@ -1273,8 +1284,7 @@ def __init__(self, a, requires_memory_management_per_element=False): self.requires_memory_management_per_element = requires_memory_management_per_element def _handle_randomness(self, create_uids): - if not create_uids: - return self + assert create_uids uid = Env.get_uid() elt = Env.get_uid() iota = StreamIota(I32(0), I32(1)) @@ -1329,18 +1339,9 @@ def _compute_type(self, env, agg_env, deep_typecheck): tarray(self.collection.typ.element_type.types[1])) -static_split_ctr = 0 uid_field_name = '__uid' -def get_static_split_uid(): - global static_split_ctr - result = static_split_ctr - assert(result <= 0xFFFF_FFFF_FFFF_FFFF) - static_split_ctr += 1 - return result - - def uid_size(type): if isinstance(type, ttuple): return len(type) @@ -1353,7 +1354,7 @@ def unify_uid_types(types, tag=False): size += 1 if size == 1: return tint64 - return ttuple(tint64 for _ in range(size)) + return ttuple(*(tint64 for _ in range(size))) def pad_uid(uid, type, tag=None): @@ -1368,10 +1369,9 @@ def pad_uid(uid, type, tag=None): else: fields = (GetTupleElement(uid, i) for i in range(size)) if tag is None: - tuple = MakeTuple([*(I64(0) for _ in range(padding)), *fields]) + return MakeTuple([*(I64(0) for _ in range(padding)), *fields]) else: - tuple = MakeTuple([I64(tag), *(I64(0) for _ in range(padding)), *fields]) - return If(IsNA(uid), NA(tuple.typ), tuple) + return MakeTuple([I64(tag), *(I64(0) for _ in range(padding)), *fields]) def concat_uids(uid1, uid2, handle_missing_left=False, handle_missing_right=False): @@ -1389,8 +1389,7 @@ def concat_uids(uid1, uid2, handle_missing_left=False, handle_missing_right=Fals fields2 = (GetTupleElement(uid2, i) for i in range(size2)) if handle_missing_right: fields2 = (Coalesce(field, I64(0)) for field in fields2) - tuple = MakeTuple([*fields1, *fields2]) - return If(Apply("lor", tbool, IsNA(uid1), IsNA(uid2)), NA(tuple.typ), tuple) + return MakeTuple([*fields1, *fields2]) def unpack_uid(stream_type): @@ -1408,7 +1407,7 @@ def unpack_uid(stream_type): def pack_uid(uid, elt): if isinstance(elt.typ, tstruct): - return InsertFields(elt, [uid_field_name, uid]) + return InsertFields(elt, [(uid_field_name, uid)], None) else: return MakeTuple([uid, elt]) @@ -1428,6 +1427,7 @@ def __init__(self, a, n): super().__init__(a, n) self.a = a self.n = n + self.needs_randomness_handling = a.needs_randomness_handling def _handle_randomness(self, create_uids): a = self.a.handle_randomness(create_uids) @@ -1450,6 +1450,7 @@ def __init__(self, a, name, body): self.a = a self.name = name self.body = body + self.needs_randomness_handling = a.needs_randomness_handling or body.uses_randomness def _handle_randomness(self, create_uids): if not self.body.uses_randomness and not create_uids: @@ -1517,6 +1518,7 @@ def __init__(self, streams, names, body, behavior, error_id=None, stack_trace=No self._stack_trace = stack_trace if error_id is None or stack_trace is None: self.save_error_info() + self.needs_randomness_handling = any(stream.needs_randomness_handling for stream in streams) or body.uses_randomness def _handle_randomness(self, create_uids): if not self.body.uses_randomness and not create_uids: @@ -1526,24 +1528,24 @@ def _handle_randomness(self, create_uids): if self.behavior == 'ExtendNA': new_streams = [stream.handle_randomness(True) for stream in self.streams] tuples, uids, elts = zip(*(unpack_uid(stream.typ) for stream in new_streams)) - uid_type = unify_uid_types(uid.typ for uid in uids) - uid = Coalesce(If(IsNA(uid), NA(uid_type), pad_uid(uid, uid_type, i)) for i, uid in enumerate(uids)) + uid_type = unify_uid_types((uid.typ for uid in uids), tag=True) + uid = Coalesce(*(If(IsNA(uid), NA(uid_type), pad_uid(uid, uid_type, i)) for i, uid in enumerate(uids))) new_body = self.body for elt, name in zip(elts, self.names): new_body = Let(name, elt, new_body) if self.body.uses_randomness: new_body = with_split_rng_state(new_body, uid) if create_uids: - pack_uid(uid, new_body) + new_body = pack_uid(uid, new_body) return StreamZip(new_streams, tuples, new_body, self.behavior, self._error_id, self._stack_trace) new_streams = [self.streams[0].handle_randomness(True), *(stream.handle_randomness(False) for stream in self.streams[1:])] - tuple, uid, elt = unpack_uid(self.streams[0].typ) + tuple, uid, elt = unpack_uid(new_streams[0].typ) new_body = Let(self.names[0], elt, self.body) if self.body.uses_randomness: new_body = with_split_rng_state(new_body, uid) if create_uids: - pack_uid(uid, new_body) + new_body = pack_uid(uid, new_body) return StreamZip(new_streams, [tuple, *self.names[1:]], new_body, self.behavior, self._error_id, self._stack_trace) @typecheck_method(children=IR) @@ -1580,6 +1582,7 @@ def __init__(self, a, name, body): self.a = a self.name = name self.body = body + self.needs_randomness_handling = a.needs_randomness_handling or body.uses_randomness def _handle_randomness(self, create_uids): if not self.body.uses_randomness and not create_uids: @@ -1633,6 +1636,7 @@ def __init__(self, a, name, body): self.a = a self.name = name self.body = body + self.needs_randomness_handling = a.needs_randomness_handling or body.uses_randomness def _handle_randomness(self, create_uids): if not self.body.uses_randomness and not create_uids: @@ -1683,8 +1687,7 @@ def renderable_bindings(self, i, default_value=None): class StreamFold(IR): @typecheck_method(a=IR, zero=IR, accum_name=str, value_name=str, body=IR) def __init__(self, a, zero, accum_name, value_name, body): - if a.uses_randomness or body.uses_randomness: - a = a.handle_randomness(create_uids=body.uses_randomness) + a = a.handle_randomness(create_uids=body.uses_randomness) if body.uses_randomness: tuple, uid, elt = unpack_uid(a.typ) body = Let(value_name, elt, body) @@ -1738,6 +1741,7 @@ def __init__(self, a, zero, accum_name, value_name, body): self.accum_name = accum_name self.value_name = value_name self.body = body + self.needs_randomness_handling = a.needs_randomness_handling or body.uses_randomness def _handle_randomness(self, create_uids): if not self.body.uses_randomness and not create_uids: @@ -1746,7 +1750,7 @@ def _handle_randomness(self, create_uids): a = self.a.handle_randomness(True) tuple, uid, elt = unpack_uid(a.typ) - new_body = Let(self.name, elt, self.body) + new_body = Let(self.value_name, elt, self.body) if self.body.uses_randomness: new_body = with_split_rng_state(new_body, uid) if create_uids: @@ -1796,26 +1800,23 @@ def __init__(self, left, right, l_key, r_key, l_name, r_name, join, join_type): self.r_name = r_name self.join = join self.join_type = join_type + self.needs_randomness_handling = left.needs_randomness_handling or right.needs_randomness_handling or join.uses_randomness def _handle_randomness(self, create_uids): if not self.join.uses_randomness and not create_uids: - if self.left.uses_randomness: - left = self.left.handle_randomness(False) - if self.right.uses_randomness: - right = self.right.handle_randomness(False) + left = self.left.handle_randomness(False) + right = self.right.handle_randomness(False) return StreamJoinRightDistinct(left, right, self.l_key, self.r_key, self.l_name, self.r_name, self.join, self.join_type) if self.join_type == 'left' or self.join_type == 'inner': left = self.left.handle_randomness(True) - if self.right.uses_randomness: - right = self.right.handle_randomness(False) + right = self.right.handle_randomness(False) r_name = self.r_name l_name, uid, l_elt = unpack_uid(left.typ) new_join = Let(self.l_name, l_elt, self.join) elif self.join_type == 'right': right = self.right.handle_randomness(True) - if self.left.uses_randomness: - left = self.left.handle_randomness(False) + left = self.left.handle_randomness(False) l_name = self.l_name r_name, uid, r_elt = unpack_uid(right.typ) new_join = Let(self.r_name, r_elt, self.join) @@ -1868,14 +1869,12 @@ def renderable_bindings(self, i, default_value=None): class StreamFor(IR): @typecheck_method(a=IR, value_name=str, body=IR) def __init__(self, a, value_name, body): + a = a.handle_randomness(body.uses_randomness) if body.uses_randomness: - a = a.handle_randomness(True) tuple, uid, elt = unpack_uid(a.typ) body = Let(value_name, elt, body) body = with_split_rng_state(body, uid) value_name = tuple - elif a.uses_randomness: - a = a.handle_randomness(False) super().__init__(a, body) self.a = a @@ -1955,14 +1954,13 @@ def uses_agg_capability(cls) -> bool: class AggExplode(IR): @typecheck_method(s=IR, name=str, agg_body=IR, is_scan=bool) def __init__(self, s, name, agg_body, is_scan): - if s.uses_randomness or agg_body.uses_agg_randomness(is_scan): - if agg_body.uses_agg_randomness(is_scan): - s = s.handle_randomness(True) - tuple, uid, elt = unpack_uid(s.typ) - agg_body = Let(self.name, elt, agg_body) - agg_body = with_split_rng_state(agg_body, uid, is_scan) - else: - s = s.handle_randomness(False) + s = s.handle_randomness(agg_body.uses_agg_randomness(is_scan)) + if agg_body.uses_agg_randomness(is_scan): + s = s.handle_randomness(True) + tuple, uid, elt = unpack_uid(s.typ) + agg_body = AggLet(name, elt, agg_body, is_scan) + agg_body = with_split_rng_state(agg_body, uid, is_scan) + name = tuple super().__init__(s, agg_body) self.name = name self.s = s @@ -2355,7 +2353,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class SelectedTopLevelReference(SelectFields): @typecheck_method(name=str, type=tstruct) def __init__(self, name, type=None): - ref = TopLevelReference(name) + ref = TopLevelReference(name, None) super().__init__(ref, type.fields) self.ref = ref self._typ = type @@ -2492,7 +2490,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class ProjectedTopLevelReference(GetField): @typecheck_method(name=str, field=str, type=HailType) def __init__(self, name, field, type=None): - ref = TopLevelReference(name) + ref = TopLevelReference(name, None) super().__init__(ref, field) self.ref = ref self.field = field @@ -2689,27 +2687,24 @@ def _compute_type(self, env, agg_env, deep_typecheck): class ApplySeeded(IR): - @typecheck_method(function=str, seed=int, rng_state=IR, return_type=hail_type, args=IR) - def __init__(self, function, seed, rng_state, return_type, *args): - if hail.current_backend().requires_lowering: - warning("Seeded randomness is currently unreliable on the service. " - "You may observe some unexpected behavior. Don't use for real work yet.") + @typecheck_method(function=str, static_rng_uid=int, rng_state=IR, return_type=hail_type, args=IR) + def __init__(self, function, static_rng_uid, rng_state, return_type, *args): super().__init__(rng_state, *args) self.function = function self.args = args self.rng_state = rng_state - self.seed = seed + self.static_rng_uid = static_rng_uid self.return_type = return_type def copy(self, rng_state, *args): - return ApplySeeded(self.function, self.seed, rng_state, self.return_type, *args) + return ApplySeeded(self.function, self.static_rng_uid, rng_state, self.return_type, *args) def head_str(self): - return f'{escape_id(self.function)} {self.seed} {self.return_type._parsable_string()}' + return f'{escape_id(self.function)} {self.static_rng_uid} {self.return_type._parsable_string()}' def _eq(self, other): return other.function == self.function and \ - other.seed == self.seed and \ + other.static_rng_uid == self.static_rng_uid and \ other.return_type == self.return_type def _compute_type(self, env, agg_env, deep_typecheck): @@ -2720,21 +2715,12 @@ def _compute_type(self, env, agg_env, deep_typecheck): class RNGStateLiteral(IR): - @typecheck_method(key=sized_tupleof(int, int, int, int)) - def __init__(self, key): - for k in key: - assert 0 <= k < 0xFFFFFFFF_FFFFFFFF + @typecheck_method() + def __init__(self): super().__init__() - self.key = key def copy(self): - return RNGStateLiteral(self.key) - - def head_str(self): - return f'({" ".join(map(str, self.key))})' - - def _eq(self, other): - return other.key == self.key + return RNGStateLiteral() def _compute_type(self, env, agg_env, deep_typecheck): return trngstate @@ -2809,12 +2795,12 @@ class TableAggregate(IR): @typecheck_method(child=TableIR, query=IR) def __init__(self, child, query): if query.uses_randomness: - child = child.handle_randomness(uid_field_name) - uid = GetField(Ref('row', child.typ.row_type), uid_field_name) + child = child.handle_randomness(default_row_uid) + uid = GetField(Ref('row', child.typ.row_type), default_row_uid) if query.uses_value_randomness: - query = Let('__rng_state', RNGStateLiteral(rng_key), query) + query = Let('__rng_state', RNGStateLiteral(), query) if query.uses_agg_randomness(is_scan=False): - query = AggLet('__rng_state', RNGSplit(RNGStateLiteral(rng_key), uid), query, is_scan=False) + query = AggLet('__rng_state', RNGSplit(RNGStateLiteral(), uid), query, is_scan=False) else: child = child.handle_randomness(None) super().__init__(child, query) @@ -2865,14 +2851,14 @@ class MatrixAggregate(IR): @typecheck_method(child=MatrixIR, query=IR) def __init__(self, child, query): if query.uses_value_randomness: - query = Let('__rng_state', RNGStateLiteral(rng_key), query) + query = Let('__rng_state', RNGStateLiteral(), query) if query.uses_agg_randomness(is_scan=False): child = child.handle_randomness(default_row_uid, default_col_uid) row_uid, old_row = unpack_row_uid(child.typ.row_type, default_row_uid) col_uid, old_col = unpack_col_uid(child.typ.col_type, default_col_uid) entry_uid = concat_uids(row_uid, col_uid) - query = AggLet('__rng_state', RNGSplit(RNGStateLiteral(rng_key), entry_uid), query, is_scan=False) + query = AggLet('__rng_state', RNGSplit(RNGStateLiteral(), entry_uid), query, is_scan=False) else: child = child.handle_randomness(None, None) diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index 84a968697e2..1c3672ff152 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -2,8 +2,9 @@ import hail as hl from hail.expr.types import HailType, tint64 from hail.ir.base_ir import BaseIR, MatrixIR -from hail.ir.utils import modify_deep_field, zip_with_index, zip_with_index_field, default_row_uid, default_col_uid, rng_key, unpack_row_uid, unpack_col_uid +from hail.ir.utils import modify_deep_field, zip_with_index, zip_with_index_field, default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid import hail.ir.ir as ir +from hail.utils import FatalError from hail.utils.misc import escape_str, parsable_strings, escape_id from hail.utils.jsonx import dump_json from hail.utils.java import Env @@ -19,11 +20,11 @@ def __init__(self, child, entry_expr, row_expr): def _handle_randomness(self, row_uid_field_name, col_uid_field_name): drop_row_uid = False drop_col_uid = False - if self.entry_expr.uses_randomness: + if self.entry_expr.uses_randomness or self.row_expr.uses_randomness: drop_row_uid = row_uid_field_name is None if row_uid_field_name is None: row_uid_field_name = default_row_uid - if self.entry_expr.uses_randomness or self.row_expr.uses_randomness: + if self.entry_expr.uses_randomness: drop_col_uid = col_uid_field_name is None if col_uid_field_name is None: col_uid_field_name = default_col_uid @@ -33,29 +34,30 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): row_expr = self.row_expr if row_uid_field_name is not None: row_uid, old_row = unpack_row_uid(child.typ.row_type, row_uid_field_name) - first_row_uid = ir.ApplyAggOp('Take', [1], [row_uid]) - entry_expr = ir.Let('va', old_row, entry_expr) + first_row_uid = ir.ArrayRef(ir.ApplyAggOp('Take', [ir.I32(1)], [row_uid]), ir.I32(0)) entry_expr = ir.AggLet('va', old_row, entry_expr, is_scan=False) + row_expr = ir.AggLet('va', old_row, row_expr, is_scan=False) + row_expr = ir.InsertFields(row_expr, [(row_uid_field_name, first_row_uid)], None) if col_uid_field_name is not None: col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) entry_expr = ir.AggLet('sa', old_col, entry_expr, is_scan=False) - row_expr = ir.AggLet('sa', old_col, row_expr, is_scan=False) + entry_expr = ir.Let('sa', old_col, entry_expr) if self.entry_expr.uses_value_randomness: entry_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), ir.concat_uids(first_row_uid, col_uid)), + ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(first_row_uid, col_uid)), entry_expr) if self.entry_expr.uses_agg_randomness(is_scan=False): entry_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), ir.concat_uids(row_uid, col_uid)), + ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), entry_expr, is_scan=False) if self.row_expr.uses_value_randomness: row_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), first_row_uid), + ir.RNGSplit(ir.RNGStateLiteral(), first_row_uid), row_expr) if self.row_expr.uses_agg_randomness(is_scan=False): row_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), row_uid), + ir.RNGSplit(ir.RNGStateLiteral(), row_uid), row_expr, is_scan=False) @@ -154,7 +156,8 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): col = ir.Ref('sa', self.typ.col_type) result = MatrixMapCols( result, - ir.InsertFields(col, [(col_uid_field_name, ir.GetField(row, default_col_uid))], None)) + ir.InsertFields(col, [(col_uid_field_name, ir.GetField(row, default_col_uid))], None), + None) if rename: result = MatrixRename(result, {}, col_map, row_map, {}) return result @@ -202,7 +205,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): uid, old_row = unpack_row_uid(child.typ.row_type, row_uid_field_name) pred = ir.Let('va', old_row, self.pred) if self.pred.uses_randomness: - pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), pred) + pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), pred) result = MatrixFilterRows(child, pred) if drop_row_uid: result = MatrixMapRows(result, old_row) @@ -263,12 +266,12 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): if row_uid_field_name is not None: row_uid, old_row = unpack_row_uid(child.typ.row_type, row_uid_field_name) if self.new_col.uses_value_randomness: - new_col = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), col_uid), new_col) + new_col = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), col_uid), new_col) if self.new_col.uses_agg_randomness(is_scan=True): - new_col = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), col_uid), new_col, is_scan=True) + new_col = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), col_uid), new_col, is_scan=True) if self.new_col.uses_agg_randomness(is_scan=False): entry_uid = ir.concat_uids(row_uid, col_uid) - new_col = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), entry_uid), new_col, is_scan=False) + new_col = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), entry_uid), new_col, is_scan=False) if keep_col_uid: new_col = ir.InsertFields(new_col, [(col_uid_field_name, col_uid)], None) result = MatrixMapCols(child, new_col, self.new_key) @@ -318,17 +321,17 @@ def __init__(self, left, right, join_type): self.join_type = join_type def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - if self.join_type == 'outer': - # FIXME: Need to make MatrixUnionCols preserve row fields from the right - # to handle the outer join case - row_uid_field_name = None + if self.join_type == 'outer' and row_uid_field_name is not None: + right_row_uid_field_name = f'{row_uid_field_name}_right' + else: + right_row_uid_field_name = None left = self.left.handle_randomness(row_uid_field_name, col_uid_field_name) - right = self.right.handle_randomness(None, col_uid_field_name) + right = self.right.handle_randomness(right_row_uid_field_name, col_uid_field_name) if col_uid_field_name is not None: - left_uid = unpack_col_uid(left.typ.col_type, col_uid_field_name) - right_uid = unpack_col_uid(right.typ.col_type, col_uid_field_name) - uid_type = ir.unify_uid_types(left_uid.typ, right_uid.typ) + (left_uid, _) = unpack_col_uid(left.typ.col_type, col_uid_field_name) + (right_uid, _) = unpack_col_uid(right.typ.col_type, col_uid_field_name) + uid_type = ir.unify_uid_types((left_uid.typ, right_uid.typ), tag=True) left = MatrixMapCols(left, ir.InsertFields(ir.Ref('sa', left.typ.col_type), [(col_uid_field_name, ir.pad_uid(left_uid, uid_type, 0))], None), @@ -339,13 +342,16 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): new_key=None) result = MatrixUnionCols(left, right, self.join_type) - # FIXME: Need to make MatrixUnionCols preserve row fields from the right - # to handle the outer join case - if row_uid_field_name is not None and self.join_type == 'outer': - result = MatrixMapRows(result, - ir.InsertFields(ir.Ref('va', result.typ.row_type), - [(row_uid_field_name, ir.NA(tint64))], - None)) + + if row_uid_field_name is not None and right_row_uid_field_name is not None: + row = ir.Ref('row', result.typ.row_type) + old_joined_row = ir.SelectFields(row, [field for field in self.typ.row_type]) + left_uid = ir.GetField(row, row_uid_field_name) + right_uid = ir.GetField(row, right_row_uid_field_name) + uid = ir.concat_uids(left_uid, right_uid, True, True) + new_row = ir.InsertFields(old_joined_row, [(row_uid_field_name, uid)], None) + result = MatrixMapRows(result, new_row) + return result def head_str(self): @@ -394,7 +400,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) new_entry = ir.Let('sa', old_col, new_entry) if self.new_entry.uses_value_randomness: - new_entry = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), ir.concat_uids(row_uid, col_uid)), new_entry) + new_entry = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), new_entry) result = MatrixMapEntries(child, new_entry) if drop_row_uid: _, old_row = unpack_row_uid(result.typ.row_type, row_uid_field_name) @@ -446,7 +452,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) pred = ir.Let('sa', old_col, pred) if self.pred.uses_value_randomness: - pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), ir.concat_uids(row_uid, col_uid)), pred) + pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), pred) result = MatrixFilterEntries(child, pred) if drop_row_uid: _, old_row = unpack_row_uid(result.typ.row_type, row_uid_field_name) @@ -520,12 +526,12 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): if col_uid_field_name is not None: col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) if self.new_row.uses_value_randomness: - new_row = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), row_uid), new_row) + new_row = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), row_uid), new_row) if self.new_row.uses_agg_randomness(is_scan=True): - new_row = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), row_uid), new_row, is_scan=True) + new_row = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), row_uid), new_row, is_scan=True) if self.new_row.uses_agg_randomness(is_scan=False): entry_uid = ir.concat_uids(row_uid, col_uid) - new_row = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), entry_uid), new_row, is_scan=False) + new_row = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), entry_uid), new_row, is_scan=False) if keep_row_uid: new_row = ir.InsertFields(new_row, [(row_uid_field_name, row_uid)], None) result = MatrixMapRows(child, new_row) @@ -571,7 +577,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): child = self.child.handle_randomness(row_uid_field_name, col_uid_field_name) new_global = self.new_global if new_global.uses_randomness: - new_global = ir.Let('__rng_state', ir.RNGStateLiteral(rng_key), new_global) + new_global = ir.Let('__rng_state', ir.RNGStateLiteral(), new_global) return MatrixMapGlobals(child, new_global) def _compute_type(self, deep_typecheck): @@ -608,7 +614,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) pred = ir.Let('sa', old_col, self.pred) if self.pred.uses_randomness: - pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), col_uid), pred) + pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), col_uid), pred) result = MatrixFilterCols(child, pred) if drop_col_uid: result = MatrixMapCols(result, old_col, new_key=None) @@ -680,27 +686,28 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): entry_expr = ir.AggLet('va', old_row, entry_expr, is_scan=False) if col_uid_field_name is not None: col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) - first_col_uid = ir.ApplyAggOp('Take', [1], [col_uid]) + first_col_uid = ir.ArrayRef(ir.ApplyAggOp('Take', [ir.I32(1)], [col_uid]), ir.I32(0)) entry_expr = ir.AggLet('sa', old_col, entry_expr, is_scan=False) col_expr = ir.AggLet('sa', old_col, col_expr, is_scan=False) + col_expr = ir.InsertFields(col_expr, [(col_uid_field_name, first_col_uid)], None) if self.entry_expr.uses_value_randomness: entry_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), + ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, first_col_uid)), entry_expr) if self.entry_expr.uses_agg_randomness(is_scan=False): entry_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), + ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), entry_expr, is_scan=False) if self.col_expr.uses_value_randomness: col_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), first_col_uid), + ir.RNGSplit(ir.RNGStateLiteral(), first_col_uid), col_expr) if self.col_expr.uses_agg_randomness(is_scan=False): col_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), col_uid), + ir.RNGSplit(ir.RNGStateLiteral(), col_uid), col_expr, is_scan=False) @@ -710,7 +717,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): result = MatrixMapRows(result, old_row) if drop_col_uid: _, old_col = unpack_col_uid(result.typ.col_type, col_uid_field_name) - result = MatrixMapCols(result, old_col) + result = MatrixMapCols(result, old_col, None) return result def _compute_type(self, deep_typecheck): @@ -770,7 +777,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): ir.Ref('va', new_explode.typ.row_type), self.path, lambda tuple: ir.GetTupleElement(tuple, 0), - lambda row, tuple: ir.InsertFields(row, (row_uid_field_name, ir.concat_uids(ir.GetField(row, row_uid_field_name), ir.GetTupleElement(tuple, 1))), None)) + lambda row, tuple: ir.InsertFields(row, [(row_uid_field_name, ir.concat_uids(ir.GetField(row, row_uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)))], None)) return MatrixMapRows(new_explode, new_row) def head_str(self): @@ -825,7 +832,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): if row_uid_field_name is not None: uids = [uid for uid, _ in (unpack_row_uid(child.typ.row_type, row_uid_field_name) for child in children)] - uid_type = ir.unify_uid_types(uid.typ for uid in uids) + uid_type = ir.unify_uid_types((uid.typ for uid in uids), tag=True) children = [MatrixMapRows(child, ir.InsertFields(ir.Ref('va', child.typ.row_type), [(row_uid_field_name, ir.pad_uid(uid, uid_type, i))], @@ -949,15 +956,15 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): return MatrixExplodeCols(child, self.path) new_col = modify_deep_field(ir.Ref('sa', child.typ.col_type), self.path, zip_with_index) - child = MatrixMapCols(child, new_col) + child = MatrixMapCols(child, new_col, None) new_explode = MatrixExplodeCols(child, self.path) new_col = modify_deep_field( ir.Ref('sa', new_explode.typ.col_type), self.path, lambda tuple: ir.GetTupleElement(tuple, 0), - lambda col, tuple: ir.InsertFields(col, (col_uid_field_name, ir.concat_uids(ir.GetField(col, col_uid_field_name), ir.GetTupleElement(tuple, 1))), None)) - return MatrixMapCols(new_explode, new_col) + lambda col, tuple: ir.InsertFields(col, [(col_uid_field_name, ir.concat_uids(ir.GetField(col, col_uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)))], None)) + return MatrixMapCols(new_explode, new_col, None) def head_str(self): return f"({' '.join([escape_id(id) for id in self.path])})" @@ -997,7 +1004,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): lambda g: zip_with_index_field(g, col_uid_field_name)) child = TableMapGlobals(child, new_globals) - return CastTableToMatrix(self.child.handle_randomness(row_uid_field_name), + return CastTableToMatrix(child.handle_randomness(row_uid_field_name), self.entries_field_name, self.cols_field_name, self.col_key) @@ -1102,15 +1109,9 @@ def __init__(self, child, config): self.config = config def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - child = self.child.handle_randomness(None, None) - result = MatrixToMatrixApply(child, self.config) - if row_uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('va', result.typ.row_type), [(row_uid_field_name, ir.NA(tint64))], None) - result = MatrixMapRows(result, new_row) - if col_uid_field_name is not None: - new_col = ir.InsertFields(ir.Ref('sa', result.typ.col_type), [(col_uid_field_name, ir.NA(tint64))], None) - result = MatrixMapCols(result, new_col, None) - return result + assert self.config['name'] == 'MatrixFilterPartitions' + child = self.child.handle_randomness(row_uid_field_name, col_uid_field_name) + return MatrixToMatrixApply(child, self.config) def head_str(self): return dump_json(self.config) @@ -1198,14 +1199,7 @@ def __init__(self, jir): self._jir = jir def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - result = self - if row_uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('va', result.typ.row_type), [(row_uid_field_name, ir.NA(tint64))], None) - result = MatrixMapRows(result, new_row) - if col_uid_field_name is not None: - new_col = ir.InsertFields(ir.Ref('sa', result.typ.col_type), [(col_uid_field_name, ir.NA(tint64))], None) - result = MatrixMapCols(result, new_col, None) - return result + raise FatalError('JavaMatrix does not support randomness in consumers') def render_head(self, r): return f'(JavaMatrix {r.add_jir(self._jir)}' @@ -1224,14 +1218,7 @@ def __init__(self, vec_ref, idx): self.idx = idx def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - result = self - if row_uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('va', result.typ.row_type), [(row_uid_field_name, ir.NA(tint64))], None) - result = MatrixMapRows(result, new_row) - if col_uid_field_name is not None: - new_col = ir.InsertFields(ir.Ref('sa', result.typ.col_type), [(col_uid_field_name, ir.NA(tint64))], None) - result = MatrixMapCols(result, new_col, None) - return result + raise FatalError('JavaMatrix does not support randomness in consumers') def head_str(self): return f'{self.vec_ref.jid} {self.idx}' diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 5ac0753bde2..8d9ad247f84 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -1,10 +1,11 @@ from typing import Optional import hail as hl -from hail.expr.types import dtype, tint32, tint64, trngstate +from hail.expr.types import dtype, tint32, tint64 from hail.ir.base_ir import BaseIR, TableIR import hail.ir.ir as ir from hail.ir.utils import modify_deep_field, zip_with_index, default_row_uid, default_col_uid -from hail.ir.ir import rng_key, unify_uid_types, pad_uid, concat_uids +from hail.ir.ir import unify_uid_types, pad_uid, concat_uids +from hail.utils import FatalError from hail.utils.java import Env from hail.utils.misc import escape_str, parsable_strings, escape_id from hail.utils.jsonx import dump_json @@ -54,17 +55,14 @@ def _handle_randomness(self, uid_field_name): joined = TableJoin(left, right, self.join_type, self.join_key) row = ir.Ref('row', joined.typ.row_type) - if '__left_uid' in joined.row_typ and '__right_uid' in joined.row_typ: - old_joined_row = ir.SelectFields(row, [field for field in self.typ.row_type]) - left_uid = ir.GetField(row, '__left_uid') - right_uid = ir.GetField(row, '__right_uid') - handle_missing_left = self.join_type == 'right' or self.join_type == 'outer' - handle_missing_right = self.join_type == 'left' or self.join_type == 'outer' - uid = concat_uids(left_uid, right_uid, handle_missing_left, handle_missing_right) - else: - old_joined_row = ir.SelectFields(row, [field for field in self.typ.row_type]) - uid = ir.NA(tint64) - TableMapRows(joined, ir.InsertFields(old_joined_row, [(uid_field_name, uid)]), None) + old_joined_row = ir.SelectFields(row, [field for field in self.typ.row_type]) + left_uid = ir.GetField(row, '__left_uid') + right_uid = ir.GetField(row, '__right_uid') + handle_missing_left = self.join_type == 'right' or self.join_type == 'outer' + handle_missing_right = self.join_type == 'left' or self.join_type == 'outer' + uid = concat_uids(left_uid, right_uid, handle_missing_left, handle_missing_right) + + return TableMapRows(joined, ir.InsertFields(old_joined_row, [(uid_field_name, uid)], None)) def head_str(self): return f'{escape_id(self.join_type)} {self.join_key}' @@ -158,12 +156,12 @@ def _handle_randomness(self, uid_field_name): new_children = [child.handle_randomness(uid_field_name) for child in self.children] - if uid_field_name is None or not all(uid_field_name in child.typ.row_type for child in new_children): + if not all(uid_field_name in child.typ.row_type for child in new_children): new_children = [child.handle_randomness(None) for child in self.children] return TableUnion(new_children) uids = [uid for uid, _ in (unpack_uid(child.typ.row_type, uid_field_name) for child in new_children)] - uid_type = unify_uid_types(uid.typ for uid in uids) + uid_type = unify_uid_types((uid.typ for uid in uids), tag=True) new_children = [TableMapRows(child, ir.InsertFields(ir.Ref('row', child.typ.row_type), [(uid_field_name, pad_uid(uid, uid_type, i))], None)) @@ -208,7 +206,7 @@ def __init__(self, child, new_globals): def _handle_randomness(self, uid_field_name): new_globals = self.new_globals if new_globals.uses_randomness: - new_globals = ir.Let('__rng_state', ir.RNGStateLiteral(rng_key), new_globals) + new_globals = ir.Let('__rng_state', ir.RNGStateLiteral(), new_globals) return TableMapGlobals(self.child.handle_randomness(uid_field_name), new_globals) @@ -235,9 +233,6 @@ def _handle_randomness(self, uid_field_name): child = self.child.handle_randomness(uid_field_name) - if uid_field_name not in child.typ.row_type.fields: - return TableExplode(child, self.path) - new_row = modify_deep_field(ir.Ref('row', child.typ.row_type), self.path, zip_with_index) child = TableMapRows(child, new_row) @@ -246,7 +241,7 @@ def _handle_randomness(self, uid_field_name): ir.Ref('row', new_explode.typ.row_type), self.path, lambda tuple: ir.GetTupleElement(tuple, 0), - lambda row, tuple: ir.InsertFields(row, (uid_field_name, concat_uids(ir.GetField(row, uid_field_name), ir.GetTupleElement(tuple, 1))), None)) + lambda row, tuple: ir.InsertFields(row, [(uid_field_name, concat_uids(ir.GetField(row, uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)))], None)) return TableMapRows(new_explode, new_row) def head_str(self): @@ -297,13 +292,13 @@ def _handle_randomness(self, uid_field_name): child = self.child.handle_randomness(None) return TableMapRows(child, self.new_row) - child = self.child.handle_randomness(ir.uid_field_name) - uid, old_row = unpack_uid(child.typ.row_type, ir.uid_field_name) + child = self.child.handle_randomness(default_row_uid) + uid, old_row = unpack_uid(child.typ.row_type, default_row_uid) new_row = ir.Let('row', old_row, self.new_row) if self.new_row.uses_value_randomness: - new_row = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), new_row) + new_row = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), new_row) if self.new_row.uses_agg_randomness(is_scan=True): - new_row = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), new_row, is_scan=True) + new_row = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), new_row, is_scan=True) if uid_field_name is not None: new_row = ir.InsertFields(new_row, [(uid_field_name, uid)], None) return TableMapRows(child, new_row) @@ -338,17 +333,9 @@ def __init__(self, child, global_name, partition_stream_name, body): self.partition_stream_name = partition_stream_name def _handle_randomness(self, uid_field_name): - # FIXME: This is tricky. Might need to disallow randomness in body - # outside of the partition stream. Or just document that randomness - # outside of the partition stream uses the same state in every partition - # (might still be more efficient then computing once and broadcasting) - child = self.child - body = self.body - if self.child.uses_randomness: - child = child.handle_randomness(None) - if body.uses_randomness: - body = ir.Let('__rng_state', ir.NA(trngstate), body) - return TableMapPartitions(child, self.global_name, self.partition_stream_name, body) + if uid_field_name is not None: + raise FatalError('TableMapPartitions does not support randomness, in its body or in consumers') + return TableMapPartitions(self.child.handle_randomness(None), self.global_name, self.partition_stream_name, self.body) def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) @@ -435,17 +422,18 @@ def __init__(self, child): self.child = child def _handle_randomness(self, uid_field_name): - from hail.ir.matrix_ir import MatrixMapEntries - # FIXME: Finish once done with MatrixIR + from hail.ir.matrix_ir import MatrixMapEntries, MatrixMapRows if uid_field_name is None: return MatrixEntriesTable(self.child.handle_randomness(None, None)) - child = self.child.handle_randomness(default_row_uid, default_col_uid) + temp_row_uid = Env.get_uid(default_row_uid) + child = self.child.handle_randomness(temp_row_uid, default_col_uid) entry = ir.Ref('g', child.typ.entry_type) - row_uid = ir.GetField(ir.Ref('va', child.typ.row_type), default_row_uid) - col_uid = ir.GetField(ir.Ref('sa', child.typ.row_type), default_col_uid) - child = MatrixMapEntries(child, ir.InsertFields(entry, [(uid_field_name, ir.concat_uids(row_uid, col_uid))], None)) - return MatrixEntriesTable(child) + row_uid = ir.GetField(ir.Ref('va', child.typ.row_type), temp_row_uid) + col_uid = ir.GetField(ir.Ref('sa', child.typ.col_type), default_col_uid) + child = MatrixMapEntries(child, ir.InsertFields(entry, [('__entry_uid', ir.concat_uids(row_uid, col_uid))], None)) + child = MatrixMapRows(child, ir.SelectFields(ir.Ref('va', child.typ.row_type), [field for field in child.typ.row_type if field != temp_row_uid])) + return TableRename(MatrixEntriesTable(child), {'__entry_uid': default_row_uid}, {}) def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) @@ -475,7 +463,7 @@ def _handle_randomness(self, uid_field_name): uid, old_row = unpack_uid(child.typ.row_type, uid_field_name) pred = ir.Let('row', old_row, self.pred) if self.pred.uses_randomness: - pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), pred) + pred = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), pred) result = TableFilter(child, pred) if drop_uid: result = TableMapRows(result, old_row) @@ -509,24 +497,24 @@ def _handle_randomness(self, uid_field_name): expr = self.expr if expr.uses_randomness or uid_field_name is not None: - first_uid = Env.get_uid() + first_uid = ir.Ref(Env.get_uid(), uid.typ) if expr.uses_randomness: expr = ir.Let( '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), ir.Ref(first_uid, uid.typ)), + ir.RNGSplit(ir.RNGStateLiteral(), first_uid), expr) if expr.uses_agg_randomness(is_scan=False): expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), + ir.RNGSplit(ir.RNGStateLiteral(), uid), expr, is_scan=False) if uid_field_name is not None: - expr = ir.InsertFields(expr, [uid_field_name, uid], None) - expr = ir.Let(first_uid, ir.ApplyAggOp('Take', [1], [uid]), expr) + expr = ir.InsertFields(expr, [(uid_field_name, first_uid)], None) + expr = ir.Let(first_uid.name, ir.ArrayRef(ir.ApplyAggOp('Take', [ir.I32(1)], [uid]), ir.I32(0)), expr) new_key = self.new_key if new_key.uses_randomness: expr = ir.Let( '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), + ir.RNGSplit(ir.RNGStateLiteral(), uid), new_key) return TableKeyByAndAggregate(child, expr, new_key, self.n_partitions, self.buffer_size) @@ -573,21 +561,21 @@ def _handle_randomness(self, uid_field_name): uid, old_row = unpack_uid(child.typ.row_type, ir.uid_field_name) expr = ir.AggLet('va', old_row, self.expr, is_scan=False) - first_uid = Env.get_uid() + first_uid = ir.Ref(Env.get_uid(), uid.typ) if expr.uses_value_randomness: expr = ir.Let( '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), ir.Ref(first_uid, uid.typ)), + ir.RNGSplit(ir.RNGStateLiteral(), first_uid), expr) if expr.uses_agg_randomness(is_scan=False): expr = ir.AggLet( '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(rng_key), uid), + ir.RNGSplit(ir.RNGStateLiteral(), uid), expr, is_scan=False) if uid_field_name is not None: - expr = ir.InsertFields(expr, [uid_field_name, first_uid], None) - expr = ir.Let(first_uid, ir.ApplyAggOp('Take', [1], [uid]), expr) + expr = ir.InsertFields(expr, [(uid_field_name, first_uid)], None) + expr = ir.Let(first_uid.name, ir.ArrayRef(ir.ApplyAggOp('Take', [ir.I32(1)], [uid]), ir.I32(0)), expr) return TableAggregateByKey(child, expr) def _compute_type(self, deep_typecheck): @@ -636,25 +624,26 @@ def _handle_randomness(self, uid_field_name): if rows_and_global.uses_randomness: rows_and_global = ir.Let( '__rng_state', - ir.RNGStateLiteral(rng_key), + ir.RNGStateLiteral(), rows_and_global) if uid_field_name is not None: - rows_and_global_ref = Env.get_uid() + rows_and_global_ref = ir.Ref(Env.get_uid(), rows_and_global.typ) row = Env.get_uid() uid = Env.get_uid() iota = ir.StreamIota(ir.I32(0), ir.I32(1)) - new_rows = ir.StreamZip( - [ir.GetField(ir.Ref(rows_and_global_ref, rows_and_global.typ), 'rows'), iota], + rows = ir.ToStream(ir.GetField(rows_and_global_ref, 'rows')) + new_rows = ir.ToArray(ir.StreamZip( + [rows, iota], [row, uid], ir.InsertFields( - ir.Ref(row, rows_and_global.typ.element_type), + ir.Ref(row, rows.typ.element_type), [(uid_field_name, ir.Cast(ir.Ref(uid, tint32), tint64))], None), - 'TakeMinLength') + 'TakeMinLength')) rows_and_global = \ - ir.Let(rows_and_global_ref, rows_and_global, + ir.Let(rows_and_global_ref.name, rows_and_global, ir.InsertFields( - ir.Ref(rows_and_global_ref, rows_and_global.typ), + rows_and_global_ref, [('rows', new_rows)], None)) return TableParallelize(rows_and_global, self.n_partitions) @@ -836,7 +825,7 @@ def _handle_randomness(self, uid_field_name): new_children = [child.handle_randomness(uid_field_name) for child in self.children] uids = [uid for uid, _ in (unpack_uid(child.typ.row_type, uid_field_name) for child in new_children)] - uid_type = unify_uid_types(uid.typ for uid in uids) + uid_type = unify_uid_types((uid.typ for uid in uids), tag=True) new_children = [ TableMapRows( child, @@ -845,13 +834,13 @@ def _handle_randomness(self, uid_field_name): pad_uid(uid, uid_type, i))], None)) for i, (child, uid) in enumerate(zip(new_children, uids))] joined = TableMultiWayZipJoin(new_children, self.data_name, self.global_name) - accum = Env.get_uid() + accum = ir.Ref(Env.get_uid(), uid_type) elt = Env.get_uid() row = ir.Ref('row', joined.typ.row_type) data = ir.GetField(row, self.data_name) uid = ir.StreamFold( - ir.ToStream(data), ir.NA(uid_type), accum, elt, - ir.If(ir.IsNA(ir.Ref(accum, uid_type)), + ir.toStream(data), ir.NA(uid_type), accum.name, elt, + ir.If(ir.IsNA(accum), ir.GetField( ir.Ref(elt, data.typ.element_type), uid_field_name), @@ -903,12 +892,10 @@ def __init__(self, child, config): self.config = config def _handle_randomness(self, uid_field_name): - child = self.child.handle_randomness(None) - result = TableToTableApply(child, self.config) if uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('row', result.typ.row_type), [(uid_field_name, ir.NA(tint64))], None) - result = TableMapRows(result, new_row) - return result + raise FatalError('TableToTableApply does not support randomness in consumers') + child = self.child.handle_randomness(None) + return TableToTableApply(child, self.config) def head_str(self): return dump_json(self.config) @@ -950,12 +937,10 @@ def __init__(self, child, config): self.config = config def _handle_randomness(self, uid_field_name): - child = self.child.handle_randomness(None, None) - result = MatrixToTableApply(child, self.config) if uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('row', result.typ.row_type), [(uid_field_name, ir.NA(tint64))], None) - result = TableMapRows(result, new_row) - return result + raise FatalError('TableToTableApply does not support randomness in consumers') + child = self.child.handle_randomness(None, None) + return MatrixToTableApply(child, self.config) def head_str(self): return dump_json(self.config) @@ -1047,11 +1032,7 @@ def __init__(self, bm, aux, config): self.config = config def _handle_randomness(self, uid_field_name): - result = self - if uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('row', result.typ.row_type), [(uid_field_name, ir.NA(tint64))], None) - result = TableMapRows(result, new_row) - return result + raise FatalError('BlockMatrixToTableApply does not support randomness in consumers') def head_str(self): return dump_json(self.config) @@ -1082,7 +1063,9 @@ def __init__(self, child): def _handle_randomness(self, uid_field_name): result = self if uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('row', result.typ.row_type), [(uid_field_name, ir.NA(tint64))], None) + row = ir.Ref('row', result.typ.row_type) + new_row = ir.InsertFields(row, [(uid_field_name, + ir.MakeTuple([ir.GetField(row, 'i'), ir.GetField(row, 'j')]))], None) result = TableMapRows(result, new_row) return result @@ -1097,11 +1080,7 @@ def __init__(self, jir): self._jir = jir def _handle_randomness(self, uid_field_name): - result = self - if uid_field_name is not None: - new_row = ir.InsertFields(ir.Ref('row', result.typ.row_type), [(uid_field_name, ir.NA(tint64))], None) - result = TableMapRows(result, new_row) - return result + raise FatalError('JavaTable does not support randomness in consumers') def render_head(self, r): return f'(JavaTable {r.add_jir(self._jir)}' diff --git a/hail/python/hail/ir/utils.py b/hail/python/hail/ir/utils.py index aa43ac4d13a..bbbf8b8aa36 100644 --- a/hail/python/hail/ir/utils.py +++ b/hail/python/hail/ir/utils.py @@ -4,10 +4,10 @@ from hail.expr.types import tint32, tint64 -def finalize_randomness(x, key=(0, 0, 0, 0)): +def finalize_randomness(x): import hail.ir.ir as ir if isinstance(x, ir.IR): - x = ir.Let('__rng_state', ir.RNGStateLiteral(key), x) + x = ir.Let('__rng_state', ir.RNGStateLiteral(), x) elif isinstance(x, ir.TableIR): x = x.handle_randomness(None) elif isinstance(x, ir.MatrixIR): @@ -17,7 +17,6 @@ def finalize_randomness(x, key=(0, 0, 0, 0)): default_row_uid = '__row_uid' default_col_uid = '__col_uid' -rng_key = (0, 1, 2, 3) def unpack_row_uid(new_row_type, uid_field_name): @@ -46,39 +45,40 @@ def modify_deep_field(struct, path, new_deep_field, new_struct=None): import hail.ir.ir as ir refs = [struct] for i in range(len(path)): - refs[i + 1] = ir.Ref(Env.gen_uid(), refs[i].typ[path[i]]) + refs.append(ir.Ref(Env.get_uid(), refs[i].typ[path[i]])) acc = new_deep_field(refs[-1]) - for parent_struct, field_name in reversed(zip(refs[:-1], path)): + for parent_struct, field_name in zip(refs[-2::-1], path[::-1]): acc = ir.InsertFields(parent_struct, [(field_name, acc)], None) - acc = new_struct(acc, refs[-1]) - for struct_ref, field_ref, field_name in reversed(zip(refs[:-1], refs[1:], path)): - acc = ir.Let(field_ref.name, ir.GetField(struct_ref, field_name)) + if new_struct is not None: + acc = new_struct(acc, refs[-1]) + for struct_ref, field_ref, field_name in zip(refs[-2::-1], refs[:0:-1], path[::-1]): + acc = ir.Let(field_ref.name, ir.GetField(struct_ref, field_name), acc) return acc def zip_with_index(array): import hail.ir.ir as ir - elt = Env.gen_uid() - inner_row_uid = Env.gen_uid() + elt = Env.get_uid() + inner_row_uid = Env.get_uid() iota = ir.StreamIota(ir.I32(0), ir.I32(1)) - return ir.StreamZip( - [ir.ToStream(array), iota], + return ir.toArray(ir.StreamZip( + [ir.toStream(array), iota], [elt, inner_row_uid], - ir.MakeTuple(ir.Ref(elt, array.typ.element_type), ir.Ref(inner_row_uid, tint32)), - 'TakeMinLength') + ir.MakeTuple((ir.Ref(elt, array.typ.element_type), ir.Ref(inner_row_uid, tint32))), + 'TakeMinLength')) def zip_with_index_field(array, idx_field_name): import hail.ir.ir as ir - elt = Env.gen_uid() - inner_row_uid = Env.gen_uid() + elt = Env.get_uid() + inner_row_uid = Env.get_uid() iota = ir.StreamIota(ir.I32(0), ir.I32(1)) - return ir.StreamZip( - [ir.ToStream(array), iota], + return ir.toArray(ir.StreamZip( + [ir.toStream(array), iota], [elt, inner_row_uid], - ir.InsertFields(ir.Ref(elt, array.typ.element_type), [(idx_field_name, ir.Ref(inner_row_uid, tint32))], None), - 'TakeMinLength') + ir.InsertFields(ir.Ref(elt, array.typ.element_type), [(idx_field_name, ir.Cast(ir.Ref(inner_row_uid, tint32), tint64))], None), + 'TakeMinLength')) def impute_type_of_partition_interval_array( diff --git a/hail/python/hail/linalg/blockmatrix.py b/hail/python/hail/linalg/blockmatrix.py index 4df6eef2ce0..6d6293eaa61 100644 --- a/hail/python/hail/linalg/blockmatrix.py +++ b/hail/python/hail/linalg/blockmatrix.py @@ -458,9 +458,9 @@ def random(cls, n_rows, n_cols, block_size=None, seed=None, gaussian=True) -> 'B if not block_size: block_size = BlockMatrix.default_block_size() - seed = seed if seed is not None else Env.next_seed() + static_rng_uid = seed if seed is not None else Env.next_static_rng_uid() - rand = BlockMatrixRandom(seed, gaussian, [n_rows, n_cols], block_size) + rand = BlockMatrixRandom(static_rng_uid, gaussian, [n_rows, n_cols], block_size) return BlockMatrix(rand) @classmethod diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 3cc2125bfa0..895bbd6c0bb 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -1991,7 +1991,7 @@ def aggregate_rows(self, expr, _localize=True) -> Any: >>> dataset.aggregate_rows(hl.struct(n_high_quality=hl.agg.count_where(dataset.qual > 40), ... mean_qual=hl.agg.mean(dataset.qual))) - Struct(n_high_quality=13, mean_qual=544323.8915384616) + Struct(n_high_quality=9, mean_qual=140054.73333333334) Notes ----- @@ -2041,7 +2041,7 @@ def aggregate_cols(self, expr, _localize=True) -> Any: >>> dataset.aggregate_cols( ... hl.struct(fraction_female=hl.agg.fraction(dataset.pheno.is_female), ... case_ratio=hl.agg.count_where(dataset.is_case) / hl.agg.count())) - Struct(fraction_female=0.48, case_ratio=1.0) + Struct(fraction_female=0.44, case_ratio=1.0) Notes ----- @@ -2091,7 +2091,7 @@ def aggregate_entries(self, expr, _localize=True): >>> dataset.aggregate_entries(hl.struct(global_gq_mean=hl.agg.mean(dataset.GQ), ... call_rate=hl.agg.fraction(hl.is_defined(dataset.GT)))) - Struct(global_gq_mean=64.01841473178543, call_rate=0.9607692307692308) + Struct(global_gq_mean=69.60514541387025, call_rate=0.9933333333333333) Notes ----- diff --git a/hail/python/hail/methods/impex.py b/hail/python/hail/methods/impex.py index 325e0065fed..1157a0a3308 100644 --- a/hail/python/hail/methods/impex.py +++ b/hail/python/hail/methods/impex.py @@ -1625,17 +1625,17 @@ def import_table(paths, if should_remove_line_expr is not None: ht = ht.filter(should_remove_line_expr, keep=False) - if len(paths) <= 1: - # With zero or one files and no filters, the first row, if it exists must be in the first - # partition, so we take this one-pass fast-path. - first_row_ht = ht._filter_partitions([0]).head(1) - else: - first_row_ht = ht.head(1) + try: + if len(paths) <= 1: + # With zero or one files and no filters, the first row, if it exists must be in the first + # partition, so we take this one-pass fast-path. + first_row_ht = ht._filter_partitions([0]).head(1) + else: + first_row_ht = ht.head(1) - if find_replace is not None: - ht = ht.annotate(text=ht['text'].replace(*find_replace)) + if find_replace is not None: + ht = ht.annotate(text=ht['text'].replace(*find_replace)) - try: first_rows = first_row_ht.annotate( header=first_row_ht.text._split_line( delimiter, missing=hl.empty_array(hl.tstr), quote=quote, regex=len(delimiter) > 1) diff --git a/hail/python/hail/methods/pca.py b/hail/python/hail/methods/pca.py index 05b90d22691..e58029649d4 100644 --- a/hail/python/hail/methods/pca.py +++ b/hail/python/hail/methods/pca.py @@ -229,7 +229,6 @@ def __init__(self, block_table, block_expr, source_table, col_key): def _make_tsm(entry_expr, block_size): mt = matrix_table_source('_make_tsm/entry_expr', entry_expr) A, ht = mt_to_table_of_ndarray(entry_expr, block_size, return_checkpointed_table_also=True) - A = A.persist() return TallSkinnyMatrix(A, A.ndarray, ht, list(mt.col_key)) @@ -259,7 +258,6 @@ def _make_tsm_from_call(call_expr, block_size, mean_center=False, hwe_normalize= mt = mt.select_entries(__x=mt.__gt) A, ht = mt_to_table_of_ndarray(mt.__x, block_size, return_checkpointed_table_also=True) - A = A.persist() return TallSkinnyMatrix(A, A.ndarray, ht, list(mt.col_key)) @@ -366,7 +364,7 @@ def _reduced_svd(A: TallSkinnyMatrix, k=10, compute_U=False, iterations=2, itera n = A.ncols # Generate random matrix G - G = hl.nd.zeros((n, L)).map(lambda n: hl.rand_norm(0, 1)) + G = hl.rand_norm(0, 1, size=(n, L)) G = hl.nd.qr(G)[0]._persist() fact = _krylov_factorization(A, G, q, compute_U) @@ -392,7 +390,7 @@ def _spectral_moments(A, num_moments, p=None, moment_samples=500, block_size=128 # TODO: When moment_samples > n, we should just do a TSQR on A, and compute # the spectrum of R. assert moment_samples < n, '_spectral_moments: moment_samples must be smaller than num cols of A' - G = hl.nd.zeros((n, moment_samples)).map(lambda n: hl.if_else(hl.rand_bool(0.5), -1, 1)) + G = hl.rand_unif(-1, 1, size=(n, moment_samples)).map(lambda x: hl.sign(x)) Q1, R1 = hl.nd.qr(G)._persist() fact = _krylov_factorization(A, Q1, p, compute_U=False) moments_and_stdevs = hl.eval(fact.spectral_moments(num_moments, R1)) @@ -423,7 +421,7 @@ def _pca_and_moments(A, k=10, num_moments=5, compute_loadings=False, q_iteration n = A.ncols # Generate random matrix G - G = hl.nd.zeros((n, L)).map(lambda n: hl.rand_norm(0, 1)) + G = hl.rand_norm(0, 1, size=(n, L)) G = hl.nd.qr(G)[0]._persist() fact = _krylov_factorization(A, G, q, compute_loadings) @@ -433,7 +431,7 @@ def _pca_and_moments(A, k=10, num_moments=5, compute_loadings=False, q_iteration p = min(num_moments // 2, 10) # Generate random matrix G2 for moment estimation - G2 = hl.nd.zeros((n, moment_samples)).map(lambda n: hl.if_else(hl.rand_bool(0.5), -1, 1)) + G2 = hl.rand_unif(-1, 1, size=(n, moment_samples)).map(lambda x: hl.sign(x)) # Project out components in subspace fact.V, which we can compute exactly G2 = G2 - fact.V @ (fact.V.T @ G2) Q1, R1 = hl.nd.qr(G2)._persist() diff --git a/hail/python/hail/methods/relatedness/identity_by_descent.py b/hail/python/hail/methods/relatedness/identity_by_descent.py index 32256d85e7d..58e2987be4c 100644 --- a/hail/python/hail/methods/relatedness/identity_by_descent.py +++ b/hail/python/hail/methods/relatedness/identity_by_descent.py @@ -101,4 +101,4 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> 'bounded': bounded, 'min': min, 'max': max, - })) + })).persist() diff --git a/hail/python/hail/methods/relatedness/pc_relate.py b/hail/python/hail/methods/relatedness/pc_relate.py index 47cabf44360..37be5b04f6e 100644 --- a/hail/python/hail/methods/relatedness/pc_relate.py +++ b/hail/python/hail/methods/relatedness/pc_relate.py @@ -346,7 +346,7 @@ def pc_relate(call_expr: CallExpression, 'maf': min_individual_maf, 'blockSize': block_size, 'minKinship': min_kinship, - 'statistics': {'kin': 0, 'kin2': 1, 'kin20': 2, 'all': 3}[statistics]})) + 'statistics': {'kin': 0, 'kin2': 1, 'kin20': 2, 'all': 3}[statistics]})).persist() if statistics == 'kin': ht = ht.drop('ibd0', 'ibd1', 'ibd2') @@ -359,7 +359,7 @@ def pc_relate(call_expr: CallExpression, ht = ht.filter(ht.i == ht.j, keep=False) col_keys = hl.literal(mt.select_cols().key_cols_by().cols().collect(), dtype=tarray(mt.col_key.dtype)) - return ht.key_by(i=col_keys[ht.i], j=col_keys[ht.j]) + return ht.key_by(i=col_keys[ht.i], j=col_keys[ht.j]).persist() def _bad_mu(mu: Float64Expression, maf: float) -> BooleanExpression: diff --git a/hail/python/hail/methods/statgen.py b/hail/python/hail/methods/statgen.py index 0ebf5fe5b17..b0a49023d2e 100644 --- a/hail/python/hail/methods/statgen.py +++ b/hail/python/hail/methods/statgen.py @@ -596,6 +596,9 @@ def process_partition(part): res = res.select_globals() + temp_file_name = hl.utils.new_temp_file("_linear_regression_rows_nd", "result") + res = res.checkpoint(temp_file_name) + return res @@ -1724,7 +1727,7 @@ def skat(key_expr, 'logistic_tolerance': tolerance } - return Table(ir.MatrixToTableApply(mt._mir, config)) + return Table(ir.MatrixToTableApply(mt._mir, config)).persist() @typecheck(p_value=expr_numeric, @@ -2610,7 +2613,7 @@ def balding_nichols_model(n_populations: int, Generate a matrix table of genotypes with 1000 variants and 100 samples across 3 populations: - >>> hl.set_global_seed(1) + >>> hl.reset_global_randomness() >>> bn_ds = hl.balding_nichols_model(3, 100, 1000) >>> bn_ds.show(n_rows=5, n_cols=5) +---------------+------------+------+------+------+------+------+ @@ -2618,18 +2621,18 @@ def balding_nichols_model(n_populations: int, +---------------+------------+------+------+------+------+------+ | locus | array | call | call | call | call | call | +---------------+------------+------+------+------+------+------+ - | 1:1 | ["A","C"] | 1/1 | 0/1 | 0/1 | 1/1 | 0/1 | - | 1:2 | ["A","C"] | 1/1 | 0/0 | 0/1 | 0/0 | 0/1 | - | 1:3 | ["A","C"] | 0/1 | 1/1 | 1/1 | 0/1 | 0/1 | - | 1:4 | ["A","C"] | 0/1 | 1/1 | 1/1 | 1/1 | 0/1 | - | 1:5 | ["A","C"] | 0/1 | 0/0 | 0/1 | 0/0 | 0/1 | + | 1:1 | ["A","C"] | 0/0 | 1/1 | 1/1 | 0/1 | 0/0 | + | 1:2 | ["A","C"] | 0/0 | 1/1 | 0/1 | 0/1 | 1/1 | + | 1:3 | ["A","C"] | 0/0 | 0/1 | 0/0 | 0/0 | 0/0 | + | 1:4 | ["A","C"] | 0/1 | 0/1 | 0/1 | 1/1 | 1/1 | + | 1:5 | ["A","C"] | 0/1 | 0/1 | 0/1 | 0/0 | 0/1 | +---------------+------------+------+------+------+------+------+ showing top 5 rows showing the first 5 of 100 columns Generate a dataset as above but with phased genotypes: - >>> hl.set_global_seed(1) + >>> hl.reset_global_randomness() >>> bn_ds = hl.balding_nichols_model(3, 100, 1000, phased=True) >>> bn_ds.show(n_rows=5, n_cols=5) +---------------+------------+------+------+------+------+------+ @@ -2637,11 +2640,11 @@ def balding_nichols_model(n_populations: int, +---------------+------------+------+------+------+------+------+ | locus | array | call | call | call | call | call | +---------------+------------+------+------+------+------+------+ - | 1:1 | ["A","C"] | 0|1 | 0|0 | 0|1 | 0|0 | 0|0 | - | 1:2 | ["A","C"] | 0|1 | 1|1 | 0|0 | 1|1 | 0|0 | - | 1:3 | ["A","C"] | 1|1 | 0|0 | 0|0 | 0|0 | 1|0 | - | 1:4 | ["A","C"] | 1|1 | 0|0 | 0|1 | 0|0 | 1|1 | - | 1:5 | ["A","C"] | 1|0 | 1|0 | 0|0 | 1|0 | 0|1 | + | 1:1 | ["A","C"] | 0|0 | 0|0 | 0|1 | 0|1 | 1|0 | + | 1:2 | ["A","C"] | 1|1 | 0|1 | 0|0 | 0|0 | 0|1 | + | 1:3 | ["A","C"] | 0|0 | 0|0 | 1|0 | 1|0 | 0|0 | + | 1:4 | ["A","C"] | 1|1 | 1|1 | 1|0 | 0|1 | 0|1 | + | 1:5 | ["A","C"] | 1|1 | 0|1 | 0|1 | 1|0 | 1|1 | +---------------+------------+------+------+------+------+------+ showing top 5 rows showing the first 5 of 100 columns @@ -2652,7 +2655,7 @@ def balding_nichols_model(n_populations: int, frequencies drawn from a truncated beta distribution with ``a = 0.01`` and ``b = 0.05`` over the interval ``[0.05, 1]``, and random seed 1: - >>> hl.set_global_seed(1) + >>> hl.reset_global_randomness() >>> bn_ds = hl.balding_nichols_model(4, 40, 150, 3, ... pop_dist=[0.1, 0.2, 0.3, 0.4], ... fst=[.02, .06, .04, .12], @@ -3258,7 +3261,7 @@ def _local_ld_prune(mt, call_field, r2=0.2, bp_window_size=1000000, memory_per_c 'r2Threshold': float(r2), 'windowSize': bp_window_size, 'maxQueueSize': max_queue_size - })) + })).persist() @typecheck(call_expr=expr_call, diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index 1f0a97b7f50..2ba03b88c6b 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -3722,7 +3722,7 @@ def summarize(self, handler=None): @typecheck_method(parts=sequenceof(int), keep=bool) def _filter_partitions(self, parts, keep=True) -> 'Table': - return Table(ir.TableToTableApply(self._tir, {'name': 'TableFilterPartitions', 'parts': parts, 'keep': keep})) + return Table(ir.TableToTableApply(self._tir, {'name': 'TableFilterPartitions', 'parts': parts, 'keep': keep})).persist() @typecheck_method(entries_field_name=str, cols_field_name=str, diff --git a/hail/python/hail/utils/__init__.py b/hail/python/hail/utils/__init__.py index db24c963fc7..67ae376acbc 100644 --- a/hail/python/hail/utils/__init__.py +++ b/hail/python/hail/utils/__init__.py @@ -1,6 +1,6 @@ from .misc import (wrap_to_list, get_env_or_default, uri_path, local_path_uri, new_temp_file, new_local_temp_dir, new_local_temp_file, with_local_temp_file, storage_level, - range_matrix_table, range_table, run_command, HailSeedGenerator, timestamp_path, + range_matrix_table, range_table, run_command, timestamp_path, _dumps_partitions, default_handler, guess_cloud_spark_provider, no_service_backend) from .hadoop_utils import (hadoop_copy, hadoop_open, hadoop_exists, hadoop_is_dir, hadoop_is_file, hadoop_ls, hadoop_scheme_supported, hadoop_stat, copy_log) @@ -41,7 +41,6 @@ 'HailUserError', 'range_table', 'range_matrix_table', - 'HailSeedGenerator', 'LinkedList', 'get_1kg', 'get_hgdp', diff --git a/hail/python/hail/utils/java.py b/hail/python/hail/utils/java.py index 8e01d04cc48..659034dc10c 100644 --- a/hail/python/hail/utils/java.py +++ b/hail/python/hail/utils/java.py @@ -2,7 +2,6 @@ import sys import re -import hail from hailtop.config import configuration_of @@ -39,7 +38,7 @@ class Env: _jutils = None _hc = None _counter = 0 - _seed_generator = None + _static_rng_uid = 0 @staticmethod def get_uid(base=None): @@ -126,14 +125,15 @@ def dummy_table(): return Env._dummy_table @staticmethod - def set_seed(seed): - Env._seed_generator = hail.utils.HailSeedGenerator(seed) + def next_static_rng_uid(): + result = Env._static_rng_uid + assert(result <= 0xFFFF_FFFF_FFFF_FFFF) + Env._static_rng_uid += 1 + return result @staticmethod - def next_seed(): - if Env._seed_generator is None: - Env.set_seed(None) - return Env._seed_generator.next_seed() + def reset_global_randomness(): + Env._static_rng_uid = 0 def scala_object(jpackage, name): diff --git a/hail/python/hail/utils/misc.py b/hail/python/hail/utils/misc.py index 5c8c5353a9b..6314944b9ba 100644 --- a/hail/python/hail/utils/misc.py +++ b/hail/python/hail/utils/misc.py @@ -11,7 +11,6 @@ from collections import defaultdict, Counter from contextlib import contextmanager from io import StringIO -from random import Random from typing import Optional from urllib.parse import urlparse @@ -520,18 +519,6 @@ def lookup_bit(byte, which_bit): return (byte >> which_bit) & 1 -class HailSeedGenerator(object): - def __init__(self, seed): - self.seed = seed - self.generator = Random(seed) - - def set_seed(self, seed): - self.__init__(seed) - - def next_seed(self): - return self.generator.randint(0, (1 << 63) - 1) - - def timestamp_path(base, suffix=''): return ''.join([base, '-', diff --git a/hail/python/test/hail/conftest.py b/hail/python/test/hail/conftest.py index acd03b2cc6b..ede4f35100c 100644 --- a/hail/python/test/hail/conftest.py +++ b/hail/python/test/hail/conftest.py @@ -4,8 +4,9 @@ import pytest -from hail import current_backend +from hail import current_backend, init, reset_global_randomness from hail.backend.service_backend import ServiceBackend +from .helpers import startTestHailContext, stopTestHailContext def pytest_collection_modifyitems(config, items): @@ -33,8 +34,20 @@ def ensure_event_loop_is_initialized_in_test_thread(): asyncio.set_event_loop(asyncio.new_event_loop()) +@pytest.fixture(scope="session", autouse=True) +def init_hail(): + startTestHailContext() + yield + stopTestHailContext() + + +@pytest.fixture(autouse=True) +def reset_randomness(init_hail): + reset_global_randomness() + + @pytest.fixture(autouse=True) -def set_query_name(request): +def set_query_name(init_hail, request): backend = current_backend() if isinstance(backend, ServiceBackend): backend.batch_attributes = dict(name=request.node.name) diff --git a/hail/python/test/hail/experimental/test_annotation_db.py b/hail/python/test/hail/experimental/test_annotation_db.py index 2fd40e3929b..376ecb31fad 100644 --- a/hail/python/test/hail/experimental/test_annotation_db.py +++ b/hail/python/test/hail/experimental/test_annotation_db.py @@ -1,27 +1,25 @@ -import unittest +import pytest import hail as hl from hail.backend.service_backend import ServiceBackend -from ..helpers import startTestHailContext, stopTestHailContext -class AnnotationDBTests(unittest.TestCase): - @classmethod - def setupAnnotationDBTests(cls): - startTestHailContext() +class TestAnnotationDB: + @pytest.fixture(scope="class") + def db_json(init_hail): backend = hl.current_backend() if isinstance(backend, ServiceBackend): backend.batch_attributes = dict(name='setupAnnotationDBTests') t = hl.utils.range_table(10) t = t.key_by(locus=hl.locus('1', t.idx + 1)) t = t.annotate(annotation=hl.str(t.idx)) - cls.tempdir_manager = hl.TemporaryDirectory() - d = cls.tempdir_manager.__enter__() + tempdir_manager = hl.TemporaryDirectory() + d = tempdir_manager.__enter__() fname = d + '/f.mt' t.write(fname) if isinstance(backend, ServiceBackend): backend.batch_attributes = dict() - cls.db_json = { + db_json = { 'unique_dataset': { 'description': 'now with unique rows!', 'url': 'https://example.com', @@ -46,16 +44,12 @@ def setupAnnotationDBTests(cls): } } - @classmethod - def tearDownAnnotationDBTests(cls): - stopTestHailContext() - cls.tempdir_manager.__exit__(None, None, None) + yield db_json - setUpClass = setupAnnotationDBTests - tearDownClass = tearDownAnnotationDBTests + tempdir_manager.__exit__(None, None, None) - def test_uniqueness(self): - db = hl.experimental.DB(region='us', cloud='gcp', config=AnnotationDBTests.db_json) + def test_uniqueness(self, db_json): + db = hl.experimental.DB(region='us', cloud='gcp', config=db_json) t = hl.utils.range_table(10) t = t.key_by(locus=hl.locus('1', t.idx + 1)) t = db.annotate_rows_db(t, 'unique_dataset', 'nonunique_dataset') diff --git a/hail/python/test/hail/experimental/test_codec.py b/hail/python/test/hail/experimental/test_codec.py index 6a6405c1de1..8c7dfe44f71 100644 --- a/hail/python/test/hail/experimental/test_codec.py +++ b/hail/python/test/hail/experimental/test_codec.py @@ -1,9 +1,6 @@ import hail as hl from test.hail.helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - UNBLOCKED_UNBUFFERED_SPEC = '{"name":"StreamBufferSpec"}' diff --git a/hail/python/test/hail/experimental/test_experimental.py b/hail/python/test/hail/experimental/test_experimental.py index 8be69c457d3..2a5aed448ad 100644 --- a/hail/python/test/hail/experimental/test_experimental.py +++ b/hail/python/test/hail/experimental/test_experimental.py @@ -5,9 +5,6 @@ from ..helpers import * from hail.utils import new_temp_file -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): @fails_service_backend() diff --git a/hail/python/test/hail/experimental/test_vcf_combiner.py b/hail/python/test/hail/experimental/test_vcf_combiner.py index d04650edab1..6cd293075b1 100644 --- a/hail/python/test/hail/experimental/test_vcf_combiner.py +++ b/hail/python/test/hail/experimental/test_vcf_combiner.py @@ -4,10 +4,7 @@ from hail.experimental.vcf_combiner import vcf_combiner as vc from hail.utils.java import Env from hail.utils.misc import new_temp_file -from ..helpers import resource, startTestHailContext, stopTestHailContext, fails_local_backend, fails_service_backend - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ..helpers import resource, fails_local_backend, fails_service_backend all_samples = ['HG00308', 'HG00592', 'HG02230', 'NA18534', 'NA20760', diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py index f48edcbdd34..33c6a3a3417 100644 --- a/hail/python/test/hail/expr/test_expr.py +++ b/hail/python/test/hail/expr/test_expr.py @@ -8,11 +8,9 @@ import hail.expr.aggregators as agg from hail.expr.types import * from hail.expr.functions import _error_from_cdf +import hail.ir as ir from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - def _test_many_equal(test_cases): expressions = [t[0] for t in test_cases] @@ -90,8 +88,6 @@ def test_zeros(self): evaled = hl.eval(hl.zeros(size)) assert evaled == [0 for i in range(size)] - @fails_service_backend() - @fails_local_backend() def test_seeded_sampling(self): sampled1 = hl.utils.range_table(50, 6).filter(hl.rand_bool(0.5)) sampled2 = hl.utils.range_table(50, 5).filter(hl.rand_bool(0.5)) @@ -2808,8 +2804,6 @@ def test_show_expression(self): +---------+ ''' - @fails_service_backend() - @fails_local_backend() def test_export_genetic_data(self): mt = hl.balding_nichols_model(1, 3, 3) mt = mt.key_cols_by(s = 's' + hl.str(mt.sample_idx)) @@ -3896,3 +3890,116 @@ def test_export_entry(delimiter, missing, header): 1, 2, 3, 2, None, 6] assert expected_collect == actual.x.collect() + + +def test_stream_randomness(): + def assert_contains_node(expr, node): + assert(expr._ir.base_search(lambda x: isinstance(x, node))) + + def assert_unique_uids(a): + n1 = hl.eval(a.to_array().length()) + n2 = len(hl.eval(hl.set(a.map(lambda x: hl.rand_int64()).to_array()))) + assert(n1 == n2) + + # test NA + a = hl.missing('array') + a = a.map(lambda x: x + hl.rand_int32(10)) + assert_contains_node(a, ir.NA) + assert(hl.eval(a) == None) + + # test If + a1 = hl._stream_range(0, 5) + a2 = hl._stream_range(2, 20) + a = hl.if_else(False, a1, a2) + assert_contains_node(a, ir.If) + assert_unique_uids(a) + + # test StreamIota + s = hl._stream_range(10).zip_with_index(0) + assert_contains_node(s, ir.StreamIota) + assert_unique_uids(s) + + # test ToArray + a = hl._stream_range(10) + a = a.map(lambda x: hl.rand_int64()).to_array() + assert_contains_node(a, ir.ToArray) + assert(len(set(hl.eval(a))) == 10) + + # test ToStream + t = hl.rbind(hl.range(10), + lambda a: (a, a.map(lambda x: hl.rand_int64()))) + assert_contains_node(t, ir.ToStream) + (a, r) = hl.eval(t) + assert(len(set(r)) == len(a)) + + # test StreamZip + a1 = hl._stream_range(10) + a2 = hl._stream_range(15) + a = hl._zip_streams(a1, a2, fill_missing=True) + assert_contains_node(a, ir.StreamZip) + assert_unique_uids(a) + a = hl._zip_streams(a1, a2, fill_missing=False) + assert_contains_node(a, ir.StreamZip) + assert_unique_uids(a) + + # test StreamFilter + a = hl._stream_range(15).filter(lambda x: x % 3 != 0) + assert_contains_node(a, ir.StreamFilter) + assert_unique_uids(a) + + # test StreamFilter + a = hl._stream_range(5).flatmap(lambda x: hl._stream_range(x)) + assert_contains_node(a, ir.StreamFlatMap) + assert_unique_uids(a) + + # test StreamFold + a = hl._stream_range(10) + a = a.fold(lambda acc, x: acc.append(hl.rand_int64()), hl.empty_array(hl.tint64)) + assert_contains_node(a, ir.StreamFold) + assert(len(set(hl.eval(a))) == 10) + + # test StreamScan + a = hl._stream_range(5) + a = a.scan(lambda acc, x: acc.append(hl.rand_int64()), hl.empty_array(hl.tint64)) + assert_contains_node(a, ir.StreamScan) + assert(len(set(hl.eval(a.to_array())[-1])) == 5) + + # test AggExplode + t = hl.utils.range_table(5) + t = t.annotate(a = hl.range(t.idx)) + a = hl.agg.explode(lambda x: hl.agg.collect_as_set(hl.rand_int64()), t.a) + assert_contains_node(a, ir.AggExplode) + assert(len(t.aggregate(a)) == 10) + + # test TableCount + t = hl.utils.range_table(10) + t = t.annotate(x = hl.rand_int64()) + assert(t.count() == 10) + + # test TableGetGlobals + t = hl.utils.range_table(10) + t = t.annotate(x = hl.rand_int64()) + g = t.index_globals() + assert_contains_node(g, ir.TableGetGlobals) + assert(len(hl.eval(g)) == 0) + + # test TableCollect + t = hl.utils.range_table(10) + t = t.annotate(x = hl.rand_int64()) + a = t.collect() + assert(len(set(a)) == 10) + + # test TableAggregate + t = hl.utils.range_table(10) + a = t.aggregate(hl.agg.collect(hl.rand_int64()).map(lambda x: x + hl.rand_int64())) + assert(len(set(a)) == 10) + + # test MatrixCount + mt = hl.utils.range_matrix_table(10, 10) + mt = mt.annotate_entries(x = hl.rand_int64()) + assert(mt.count() == (10, 10)) + + # test MatrixAggregate + mt = hl.utils.range_matrix_table(5, 5) + a = mt.aggregate_entries(hl.agg.collect(hl.rand_int64()).map(lambda x: x + hl.rand_int64())) + assert(len(set(a)) == 25) diff --git a/hail/python/test/hail/expr/test_ndarrays.py b/hail/python/test/hail/expr/test_ndarrays.py index 1974be5f88c..24f9a6ed6e0 100644 --- a/hail/python/test/hail/expr/test_ndarrays.py +++ b/hail/python/test/hail/expr/test_ndarrays.py @@ -4,9 +4,6 @@ from hail.utils.java import FatalError, HailUserError -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - def assert_ndarrays(asserter, exprs_and_expecteds): exprs, expecteds = zip(*exprs_and_expecteds) diff --git a/hail/python/test/hail/expr/test_show.py b/hail/python/test/hail/expr/test_show.py index dfb851bce02..d1acfc3bacd 100644 --- a/hail/python/test/hail/expr/test_show.py +++ b/hail/python/test/hail/expr/test_show.py @@ -1,11 +1,7 @@ -from ..helpers import startTestHailContext, stopTestHailContext import unittest import hail as hl -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test(self): diff --git a/hail/python/test/hail/expr/test_types.py b/hail/python/test/hail/expr/test_types.py index 7fc96e2ff02..9df0027b53c 100644 --- a/hail/python/test/hail/expr/test_types.py +++ b/hail/python/test/hail/expr/test_types.py @@ -5,9 +5,6 @@ from ..helpers import * from hail.utils.java import Env -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def types_to_test(self): diff --git a/hail/python/test/hail/fs/test_worker_driver_fs.py b/hail/python/test/hail/fs/test_worker_driver_fs.py index eda8d26e162..76e8d312ad8 100644 --- a/hail/python/test/hail/fs/test_worker_driver_fs.py +++ b/hail/python/test/hail/fs/test_worker_driver_fs.py @@ -1,11 +1,7 @@ import hail as hl -from ..helpers import startTestHailContext, stopTestHailContext from hailtop.utils import secret_alnum_string from hailtop.test_utils import skip_in_azure -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - @skip_in_azure def test_requester_pays_no_settings(): diff --git a/hail/python/test/hail/genetics/test_call.py b/hail/python/test/hail/genetics/test_call.py index 25d8f8ef2c1..63a0712a545 100644 --- a/hail/python/test/hail/genetics/test_call.py +++ b/hail/python/test/hail/genetics/test_call.py @@ -3,9 +3,6 @@ from hail.genetics import * from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_hom_ref(self): diff --git a/hail/python/test/hail/genetics/test_locus.py b/hail/python/test/hail/genetics/test_locus.py index b85ca96e4f8..ebf0491aeb3 100644 --- a/hail/python/test/hail/genetics/test_locus.py +++ b/hail/python/test/hail/genetics/test_locus.py @@ -4,9 +4,6 @@ from hail.genetics import * from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_constructor(self): diff --git a/hail/python/test/hail/genetics/test_pedigree.py b/hail/python/test/hail/genetics/test_pedigree.py index 53a194fdead..e42084b0ace 100644 --- a/hail/python/test/hail/genetics/test_pedigree.py +++ b/hail/python/test/hail/genetics/test_pedigree.py @@ -4,9 +4,6 @@ from ..helpers import * from hail.utils.java import FatalError -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): diff --git a/hail/python/test/hail/genetics/test_reference_genome.py b/hail/python/test/hail/genetics/test_reference_genome.py index c79438a6f15..1f3316ceea2 100644 --- a/hail/python/test/hail/genetics/test_reference_genome.py +++ b/hail/python/test/hail/genetics/test_reference_genome.py @@ -5,9 +5,6 @@ from ..helpers import * from hail.utils import FatalError -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): diff --git a/hail/python/test/hail/helpers.py b/hail/python/test/hail/helpers.py index c38844cce50..a0f448c3a02 100644 --- a/hail/python/test/hail/helpers.py +++ b/hail/python/test/hail/helpers.py @@ -7,22 +7,17 @@ from hail.utils.java import Env, choose_backend import hail as hl -_initialized = False - def startTestHailContext(): - global _initialized - if not _initialized: - backend_name = choose_backend() - if backend_name == 'spark': - hl.init(master='local[2]', min_block_size=0, quiet=True) - else: - Env.hc() # force initialization - _initialized = True + backend_name = choose_backend() + if backend_name == 'spark': + hl.init(master='local[2]', min_block_size=0, quiet=True, global_seed=0) + else: + hl.init(global_seed=0) def stopTestHailContext(): - pass + hl.stop() _test_dir = os.environ.get('HAIL_TEST_RESOURCES_DIR', '../src/test/resources') _doctest_dir = os.environ.get('HAIL_DOCTEST_DATA_DIR', 'hail/docs/data') diff --git a/hail/python/test/hail/linalg/test_linalg.py b/hail/python/test/hail/linalg/test_linalg.py index 3258c7e172e..46b0645ebb3 100644 --- a/hail/python/test/hail/linalg/test_linalg.py +++ b/hail/python/test/hail/linalg/test_linalg.py @@ -10,9 +10,6 @@ import math from hail.expr.expressions import ExpressionException -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - def sparsify_numpy(np_mat, block_size, blocks_to_sparsify): n_rows, n_cols = np_mat.shape diff --git a/hail/python/test/hail/matrixtable/test_file_formats.py b/hail/python/test/hail/matrixtable/test_file_formats.py index 6488c3728e0..af3a0426d36 100644 --- a/hail/python/test/hail/matrixtable/test_file_formats.py +++ b/hail/python/test/hail/matrixtable/test_file_formats.py @@ -5,9 +5,6 @@ from hail.utils.java import Env, scala_object from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - def create_backward_compatibility_files(): import os @@ -39,12 +36,12 @@ def test_write(): @pytest.fixture(scope="module") -def all_values_matrix_table_fixture(): +def all_values_matrix_table_fixture(init_hail): return create_all_values_matrix_table() @pytest.fixture(scope="module") -def all_values_table_fixture(): +def all_values_table_fixture(init_hail): return create_all_values_table() diff --git a/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py b/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py index 9c1a3d9e8b0..601f32773a2 100644 --- a/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py @@ -3,9 +3,6 @@ import hail as hl from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index 05bd6ef816f..6da17ff93f1 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -4,14 +4,12 @@ import pytest import hail as hl +import hail.ir as ir import hail.expr.aggregators as agg from hail.utils.java import Env from hail.utils.misc import new_temp_file from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def get_mt(self, min_partitions=None) -> hl.MatrixTable: @@ -1783,3 +1781,331 @@ def test_filter_against_invalid_contig(): mt = hl.balding_nichols_model(3, 5, 20) fmt = mt.filter_rows(mt.locus.contig == "chr1") assert fmt.rows()._force_count() == 0 + + +def test_matrix_randomness(): + def assert_unique_uids(mt): + x = mt.aggregate_rows(hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())) + assert(len(x.r) == x.n) + x = mt.aggregate_cols(hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())) + assert(len(x.r) == x.n) + x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())) + assert(len(x.r) == x.n) + + def assert_contains_node(t, node): + assert(t._mir.base_search(lambda x: isinstance(x, node))) + + # test MatrixRead + mt = hl.utils.range_matrix_table(10, 10, 3) + assert_contains_node(mt, ir.MatrixRead) + assert_unique_uids(mt) + + # test MatrixAggregateRowsByKey + rmt = hl.utils.range_matrix_table(20, 10, 3) + # with body randomness + mt = (rmt.group_rows_by(k=rmt.row_idx % 5) + .aggregate_rows(r=hl.rand_int64()) + .aggregate_entries(e=hl.rand_int64()) + .result()) + assert_contains_node(mt, ir.MatrixAggregateRowsByKey) + x = mt.aggregate_rows(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n) + x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(mt.e), n=hl.agg.count())) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # with agg randomness + mt = (rmt.group_rows_by(k=rmt.row_idx % 5) + .aggregate_rows(r=hl.agg.collect(hl.rand_int64())) + .aggregate_entries(e=hl.agg.collect(hl.rand_int64())) + .result()) + assert_contains_node(mt, ir.MatrixAggregateRowsByKey) + x = mt.aggregate_rows(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) + assert(len(x.r) == x.n) + x = mt.aggregate_entries(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.e)) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # w/o body randomness + mt = (rmt.group_rows_by(k=rmt.row_idx % 5) + .aggregate_rows(row_agg=hl.agg.sum(rmt.row_idx)) + .aggregate_entries(entry_agg=hl.agg.sum(rmt.row_idx + rmt.col_idx)) + .result()) + assert_contains_node(mt, ir.MatrixAggregateRowsByKey) + assert_unique_uids(mt) + + # test MatrixFilterRows + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with cond randomness + mt = rmt.filter_rows(hl.rand_int64() % 2 == 0) + assert_contains_node(mt, ir.MatrixFilterRows) + mt.entries()._force_count() # test with no consumer randomness + assert_unique_uids(mt) + # w/o cond randomness + mt = rmt.filter_rows(rmt.row_idx < 5) + assert_contains_node(mt, ir.MatrixFilterRows) + assert_unique_uids(mt) + + # test MatrixChooseCols + rmt = hl.utils.range_matrix_table(10, 10, 3) + mt = rmt.choose_cols([2, 3, 7]) + assert_contains_node(mt, ir.MatrixChooseCols) + assert_unique_uids(mt) + + # test MatrixMapCols + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with body randomness + mt = rmt.annotate_cols(r=hl.rand_int64()) + assert_contains_node(mt, ir.MatrixMapCols) + x = mt.aggregate_cols(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # with agg randomness + mt = rmt.annotate_cols(r=hl.agg.collect(hl.rand_int64())) + assert_contains_node(mt, ir.MatrixMapCols) + x = mt.aggregate_cols(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # with scan randomness + mt = rmt.annotate_cols(r=hl.scan.collect(hl.rand_int64())) + assert_contains_node(mt, ir.MatrixMapCols) + x = mt.aggregate_cols(hl.struct(r=hl.agg.explode(lambda r: hl.agg.collect_as_set(r), mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n - 1) + assert_unique_uids(mt) + # w/o body randomness + mt = rmt.annotate_cols(x=2*rmt.col_idx) + assert_contains_node(mt, ir.MatrixMapCols) + assert_unique_uids(mt) + + # test MatrixUnionCols + r, c = 5, 5 + mt = hl.utils.range_matrix_table(2*r, c) + mt2 = hl.utils.range_matrix_table(2*r, c) + mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r) + mt2 = mt2.key_cols_by(col_idx=mt2.col_idx + c) + mt = mt.union_cols(mt2) + assert_contains_node(mt, ir.MatrixUnionCols) + assert_unique_uids(mt) + + # test MatrixMapEntries + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with body randomness + mt = rmt.annotate_entries(r=hl.rand_int64()) + assert_contains_node(mt, ir.MatrixMapEntries) + x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # w/o body randomness + mt = rmt.annotate_entries(x=rmt.row_idx + rmt.col_idx) + assert_contains_node(mt, ir.MatrixMapEntries) + assert_unique_uids(mt) + + # test MatrixFilterEntries + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with cond randomness + mt = rmt.filter_entries(hl.rand_int64() % 2 == 0) + assert_contains_node(mt, ir.MatrixFilterEntries) + mt.entries()._force_count() # test with no consumer randomness + assert_unique_uids(mt) + # w/o cond randomness + mt = rmt.filter_entries(rmt.row_idx + rmt.col_idx < 10) + assert_contains_node(mt, ir.MatrixFilterEntries) + assert_unique_uids(mt) + + # test MatrixKeyRowsBy + rmt = hl.utils.range_matrix_table(10, 10, 3) + mt = rmt.key_rows_by(k=rmt.row_idx // 4) + assert_contains_node(mt, ir.MatrixKeyRowsBy) + assert_unique_uids(mt) + + # test MatrixMapRows + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with body randomness + mt = rmt.annotate_rows(r=hl.rand_int64()) + assert_contains_node(mt, ir.MatrixMapRows) + x = mt.aggregate_rows(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # with agg randomness + mt = rmt.annotate_rows(r=hl.agg.collect(hl.rand_int64())) + assert_contains_node(mt, ir.MatrixMapRows) + x = mt.aggregate_rows(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # with scan randomness + mt = rmt.annotate_rows(r=hl.scan.collect(hl.rand_int64())) + assert_contains_node(mt, ir.MatrixMapRows) + x = mt.aggregate_rows(hl.struct(r=hl.agg.explode(lambda r: hl.agg.collect_as_set(r), mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n - 1) + assert_unique_uids(mt) + # w/o body randomness + mt = rmt.annotate_rows(x=2*rmt.row_idx) + assert_contains_node(mt, ir.MatrixMapRows) + assert_unique_uids(mt) + + # test MatrixMapGlobals + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with body randomness + mt = rmt.annotate_globals(x=hl.rand_int64()) + assert_contains_node(mt, ir.MatrixMapGlobals) + mt.entries()._force_count() # test with no consumer randomness + assert_unique_uids(mt) + # w/o body randomness + mt = rmt.annotate_globals(x=1) + assert_contains_node(mt, ir.MatrixMapGlobals) + assert_unique_uids(mt) + + # test MatrixFilterCols + rmt = hl.utils.range_matrix_table(10, 10, 3) + # with cond randomness + mt = rmt.filter_cols(hl.rand_int64() % 2 == 0) + assert_contains_node(mt, ir.MatrixFilterCols) + mt.entries()._force_count() # test with no consumer randomness + assert_unique_uids(mt) + # w/o cond randomness + mt = rmt.filter_cols(rmt.col_idx < 5) + assert_contains_node(mt, ir.MatrixFilterCols) + assert_unique_uids(mt) + + # test MatrixCollectColsByKey + rmt = hl.utils.range_matrix_table(10, 10, 3) + mt = rmt.key_cols_by(k=rmt.col_idx % 5) + mt = mt.collect_cols_by_key() + assert_contains_node(mt, ir.MatrixCollectColsByKey) + assert_unique_uids(mt) + + # test MatrixAggregateColsByKey + rmt = hl.utils.range_matrix_table(20, 10, 3) + # with body randomness + mt = (rmt.group_cols_by(k=rmt.col_idx % 5) + .aggregate_cols(r=hl.rand_int64()) + .aggregate_entries(e=hl.rand_int64()) + .result()) + assert_contains_node(mt, ir.MatrixAggregateColsByKey) + x = mt.aggregate_cols(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) + assert(len(x.r) == x.n) + x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(mt.e), n=hl.agg.count())) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # with agg randomness + mt = (rmt.group_cols_by(k=rmt.col_idx % 5) + .aggregate_cols(r=hl.agg.collect(hl.rand_int64())) + .aggregate_entries(e=hl.agg.collect(hl.rand_int64())) + .result()) + assert_contains_node(mt, ir.MatrixAggregateColsByKey) + x = mt.aggregate_cols(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) + assert(len(x.r) == x.n) + x = mt.aggregate_entries(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.e)) + assert(len(x.r) == x.n) + assert_unique_uids(mt) + # w/o body randomness + mt = (rmt.group_cols_by(k=rmt.col_idx % 5) + .aggregate_cols(row_agg=hl.agg.sum(rmt.col_idx)) + .aggregate_entries(entry_agg=hl.agg.sum(rmt.row_idx + rmt.col_idx)) + .result()) + assert_contains_node(mt, ir.MatrixAggregateColsByKey) + assert_unique_uids(mt) + + # test MatrixExplodeRows + rmt = hl.utils.range_matrix_table(20, 10, 3) + mt = rmt.annotate_rows(s=hl.struct(a=hl.range(rmt.row_idx))) + mt = mt.explode_rows(mt.s.a) + assert_contains_node(mt, ir.MatrixExplodeRows) + assert_unique_uids(mt) + + # test MatrixRepartition + if not hl.current_backend().requires_lowering: + rmt = hl.utils.range_matrix_table(20, 10, 3) + mt = rmt.repartition(5) + assert_contains_node(mt, ir.MatrixRepartition) + assert_unique_uids(mt) + + # test MatrixUnionRows + r, c = 5, 5 + mt = hl.utils.range_matrix_table(2*r, c) + mt2 = hl.utils.range_matrix_table(2*r, c) + mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r) + mt = mt.union_rows(mt2) + assert_contains_node(mt, ir.MatrixUnionRows) + assert_unique_uids(mt) + + # test MatrixDistinctByRow + rmt = hl.utils.range_matrix_table(20, 10, 3) + mt = rmt.key_rows_by(k=rmt.row_idx % 5) + mt = mt.distinct_by_row() + assert_contains_node(mt, ir.MatrixDistinctByRow) + assert_unique_uids(mt) + + # test MatrixRowsHead + rmt = hl.utils.range_matrix_table(20, 10, 3) + mt = rmt.head(10) + assert_contains_node(mt, ir.MatrixRowsHead) + assert_unique_uids(mt) + + # test MatrixColsHead + rmt = hl.utils.range_matrix_table(10, 20, 3) + mt = rmt.head(None, 10) + assert_contains_node(mt, ir.MatrixColsHead) + assert_unique_uids(mt) + + # test MatrixRowsTail + rmt = hl.utils.range_matrix_table(20, 10, 3) + mt = rmt.tail(10) + assert_contains_node(mt, ir.MatrixRowsTail) + assert_unique_uids(mt) + + # test MatrixColsTail + rmt = hl.utils.range_matrix_table(10, 20, 3) + mt = rmt.tail(None, 10) + assert_contains_node(mt, ir.MatrixColsTail) + assert_unique_uids(mt) + + # test MatrixExplodeCols + rmt = hl.utils.range_matrix_table(10, 20, 3) + mt = rmt.annotate_cols(s=hl.struct(a=hl.range(rmt.col_idx))) + mt = mt.explode_cols(mt.s.a) + assert_contains_node(mt, ir.MatrixExplodeCols) + assert_unique_uids(mt) + + # test CastTableToMatrix + rt = hl.utils.range_table(10, 3) + t = rt.annotate(e=hl.range(10).map(lambda i: hl.struct(x=i))) + t = t.annotate_globals(c=hl.range(10).map(lambda i: hl.struct(y=i))) + mt = t._unlocalize_entries('e', 'c', []) + assert_contains_node(mt, ir.CastTableToMatrix) + assert_unique_uids(mt) + + # test MatrixAnnotateRowsTable + t = hl.utils.range_table(12, 3) + t = t.key_by(k=(t.idx // 2) * 2) + mt = hl.utils.range_matrix_table(8, 10, 3) + mt = mt.key_rows_by(k=(mt.row_idx // 2) * 3) + joined = mt.annotate_rows(x=t[mt.k].idx) + assert_contains_node(joined, ir.MatrixAnnotateRowsTable) + assert_unique_uids(joined) + + # test MatrixAnnotateColsTable + t = hl.utils.range_table(12, 3) + t = t.key_by(k=(t.idx // 2) * 2) + mt = hl.utils.range_matrix_table(10, 8, 3) + mt = mt.key_cols_by(k=(mt.col_idx // 2) * 3) + joined = mt.annotate_cols(x=t[mt.k].idx) + assert_contains_node(joined, ir.MatrixAnnotateColsTable) + assert_unique_uids(joined) + + # test MatrixToMatrixApply + rmt = hl.utils.range_matrix_table(10, 10, 3) + mt = rmt._filter_partitions([0, 2]) + assert_contains_node(mt, ir.MatrixToMatrixApply) + assert_unique_uids(mt) + + # test MatrixRename + rmt = hl.utils.range_matrix_table(10, 10, 3) + mt = rmt.rename({'row_idx': 'row_index'}) + assert_contains_node(mt, ir.MatrixRename) + assert_unique_uids(mt) + + # test MatrixFilterIntervals + rmt = hl.utils.range_matrix_table(20, 10, 3) + intervals = [hl.interval(0, 5), hl.interval(10, 15)] + mt = hl.filter_intervals(rmt, intervals) + assert_contains_node(mt, ir.MatrixFilterIntervals) + assert_unique_uids(mt) diff --git a/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py b/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py index dfa7289bc3f..4259d25ff8b 100644 --- a/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py +++ b/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py @@ -4,11 +4,7 @@ import hail as hl import hail.utils as utils -from ...helpers import (startTestHailContext, stopTestHailContext, get_dataset, - fails_service_backend, fails_local_backend) - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ...helpers import get_dataset, fails_service_backend, fails_local_backend class Tests(unittest.TestCase): diff --git a/hail/python/test/hail/methods/relatedness/test_pc_relate.py b/hail/python/test/hail/methods/relatedness/test_pc_relate.py index bea71723d2b..0cc52ee9a68 100644 --- a/hail/python/test/hail/methods/relatedness/test_pc_relate.py +++ b/hail/python/test/hail/methods/relatedness/test_pc_relate.py @@ -1,10 +1,7 @@ import hail as hl import hail.utils as utils -from ...helpers import (resource, startTestHailContext, stopTestHailContext, skip_when_service_backend) - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ...helpers import resource, skip_when_service_backend def test_pc_relate_against_R_truth(): diff --git a/hail/python/test/hail/methods/test_family_methods.py b/hail/python/test/hail/methods/test_family_methods.py index 445590148e3..47e7e66f1f0 100644 --- a/hail/python/test/hail/methods/test_family_methods.py +++ b/hail/python/test/hail/methods/test_family_methods.py @@ -3,9 +3,6 @@ import hail as hl from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_trio_matrix(self): diff --git a/hail/python/test/hail/methods/test_impex.py b/hail/python/test/hail/methods/test_impex.py index f6b989c749d..e7697496d6e 100644 --- a/hail/python/test/hail/methods/test_impex.py +++ b/hail/python/test/hail/methods/test_impex.py @@ -15,9 +15,6 @@ from hail import ir from hail.utils import new_temp_file, FatalError, run_command, uri_path, HailUserError -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - _FLOAT_INFO_FIELDS = [ 'BaseQRankSum', 'ClippingRankSum', diff --git a/hail/python/test/hail/methods/test_king.py b/hail/python/test/hail/methods/test_king.py index 7da1e19b4cb..1e86a1ae563 100644 --- a/hail/python/test/hail/methods/test_king.py +++ b/hail/python/test/hail/methods/test_king.py @@ -1,10 +1,7 @@ import pytest import hail as hl -from ..helpers import resource, startTestHailContext, stopTestHailContext, fails_local_backend, fails_service_backend - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ..helpers import resource, fails_local_backend, fails_service_backend def assert_c_king_same_as_hail_king(c_king_path, hail_king_mt): diff --git a/hail/python/test/hail/methods/test_misc.py b/hail/python/test/hail/methods/test_misc.py index 66739dff3ae..54f05951316 100644 --- a/hail/python/test/hail/methods/test_misc.py +++ b/hail/python/test/hail/methods/test_misc.py @@ -3,9 +3,6 @@ import hail as hl from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_rename_duplicates(self): diff --git a/hail/python/test/hail/methods/test_pca.py b/hail/python/test/hail/methods/test_pca.py index e21f97e37b3..e0155dea95e 100644 --- a/hail/python/test/hail/methods/test_pca.py +++ b/hail/python/test/hail/methods/test_pca.py @@ -4,12 +4,7 @@ import hail as hl from hail.methods.pca import _make_tsm -from ..helpers import (resource, startTestHailContext, stopTestHailContext, fails_local_backend, - fails_service_backend, skip_when_service_backend) - - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ..helpers import resource, fails_local_backend, fails_service_backend, skip_when_service_backend @fails_local_backend() diff --git a/hail/python/test/hail/methods/test_qc.py b/hail/python/test/hail/methods/test_qc.py index a0d041c807a..ec827a45649 100644 --- a/hail/python/test/hail/methods/test_qc.py +++ b/hail/python/test/hail/methods/test_qc.py @@ -4,9 +4,6 @@ import hail.expr.aggregators as agg from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_sample_qc(self): diff --git a/hail/python/test/hail/methods/test_statgen.py b/hail/python/test/hail/methods/test_statgen.py index 6fde93ba006..ef41d522891 100644 --- a/hail/python/test/hail/methods/test_statgen.py +++ b/hail/python/test/hail/methods/test_statgen.py @@ -1,6 +1,5 @@ import os import math -import unittest import pytest import numpy as np @@ -9,16 +8,15 @@ import hail.utils as utils from hail.linalg import BlockMatrix from hail.utils import FatalError -from hail.utils.java import choose_backend -from ..helpers import (startTestHailContext, stopTestHailContext, resource, - fails_local_backend, fails_service_backend) +from hail.utils.java import choose_backend, Env +from ..helpers import resource, fails_local_backend, fails_service_backend -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +class Tests: + def __init__(self): + pass -class Tests(unittest.TestCase): - @unittest.skipIf('HAIL_TEST_SKIP_PLINK' in os.environ, 'Skipping tests requiring plink') + @pytest.mark.skipif('HAIL_TEST_SKIP_PLINK' in os.environ, 'Skipping tests requiring plink') @fails_service_backend() def test_impute_sex_same_as_plink(self): ds = hl.import_vcf(resource('x-chromosome.vcf')) @@ -454,6 +452,7 @@ def eq(x1, x2): eq(combined.p_value, combined.multi.p_value[0]) & eq(combined.multi.p_value[0], combined.multi.p_value[1])))) + def test_logistic_regression_rows_max_iter_zero(self): import hail as hl mt = hl.utils.range_matrix_table(1, 3) @@ -596,7 +595,7 @@ def test_weighted_linear_regression(self): def equal_with_nans(arr1, arr2): def both_nan_or_none(a, b): - return (a is None or np.isnan(a)) and (b is None or np.isnan(b)) + return (a is None or not np.isfinite(a)) and (b is None or not np.isfinite(b)) return all([both_nan_or_none(a, b) or math.isclose(a, b) for a, b in zip(arr1, arr2)]) @@ -1278,11 +1277,8 @@ def test_poisson_pass_through(self): assert mt.aggregate_rows(hl.agg.all(mt.foo.bar == ht[mt.row_key].bar)) - @fails_local_backend() - @fails_service_backend() def test_genetic_relatedness_matrix(self): n, m = 100, 200 - hl.set_global_seed(0) mt = hl.balding_nichols_model(3, n, m, fst=[.9, .9, .9], n_partitions=4) g = BlockMatrix.from_entry_expr(mt.GT.n_alt_alleles()).to_numpy().T @@ -1312,11 +1308,9 @@ def _filter_and_standardize_cols(a): col_filter = col_lengths > 0 return np.copy(a[:, np.squeeze(col_filter)] / col_lengths[col_filter]) - @fails_service_backend() - @fails_local_backend() def test_realized_relationship_matrix(self): n, m = 100, 200 - hl.set_global_seed(0) + hl.reset_global_randomness() mt = hl.balding_nichols_model(3, n, m, fst=[.9, .9, .9], n_partitions=4) g = BlockMatrix.from_entry_expr(mt.GT.n_alt_alleles()).to_numpy().T @@ -1352,11 +1346,8 @@ def test_row_correlation_vs_hardcode(self): self.assertTrue(np.allclose(actual, expected)) - @fails_service_backend() - @fails_local_backend() def test_row_correlation_vs_numpy(self): n, m = 11, 10 - hl.set_global_seed(0) mt = hl.balding_nichols_model(3, n, m, fst=[.9, .9, .9], n_partitions=2) mt = mt.annotate_rows(sd=agg.stats(mt.GT.n_alt_alleles()).stdev) mt = mt.filter_rows(mt.sd > 1e-30) @@ -1552,7 +1543,7 @@ def test_ld_prune_with_duplicate_row_keys(self): self.assertEqual(pruned_table.count(), 1) def test_balding_nichols_model(self): - hl.set_global_seed(1) + hl.reset_global_randomness() ds = hl.balding_nichols_model(2, 20, 25, 3, pop_dist=[1.0, 2.0], fst=[.02, .06], @@ -1573,13 +1564,13 @@ def test_balding_nichols_model(self): def test_balding_nichols_model_same_results(self): for mixture in [True, False]: - hl.set_global_seed(1) + hl.reset_global_randomness() ds1 = hl.balding_nichols_model(2, 20, 25, 3, pop_dist=[1.0, 2.0], fst=[.02, .06], af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95), mixture=mixture) - hl.set_global_seed(1) + hl.reset_global_randomness() ds2 = hl.balding_nichols_model(2, 20, 25, 3, pop_dist=[1.0, 2.0], fst=[.02, .06], @@ -1589,7 +1580,7 @@ def test_balding_nichols_model_same_results(self): def test_balding_nichols_model_af_ranges(self): def test_af_range(rand_func, min, max, seed): - hl.set_global_seed(seed) + hl.reset_global_randomness() bn = hl.balding_nichols_model(3, 400, 400, af_dist=rand_func) self.assertTrue( bn.aggregate_rows( @@ -1603,7 +1594,7 @@ def test_af_range(rand_func, min, max, seed): def test_balding_nichols_stats(self): def test_stat(k, n, m, seed): - hl.set_global_seed(seed) + hl.reset_global_randomness() bn = hl.balding_nichols_model(k, n, m, af_dist=hl.rand_unif(0.1, 0.9)) # test pop distribution @@ -1638,18 +1629,17 @@ def variance(expr): test_stat(40, 400, 20, 12) def test_balding_nichols_model_phased(self): - hl.set_global_seed(1) bn_ds = hl.balding_nichols_model(1, 5, 5, phased=True) assert bn_ds.aggregate_entries(hl.agg.all(bn_ds.GT.phased)) == True actual = bn_ds.GT.collect() expected = [ hl.Call(a, phased=True) for a in [ - [0, 1], [0, 0], [0, 1], [0, 0], [0, 0], - [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], - [0, 1], [0, 0], [0, 0], [1, 1], [0, 1], - [1, 1], [1, 1], [1, 0], [1, 1], [1, 0], - [0, 0], [0, 0], [0, 1], [0, 0], [1, 0]]] + [0, 1], [0, 0], [0, 1], [0, 0], [1, 0], + [1, 1], [0, 1], [1, 1], [0, 0], [0, 1], + [1, 0], [0, 0], [1, 0], [0, 0], [0, 0], + [1, 1], [1, 1], [1, 0], [0, 1], [1, 1], + [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]] assert actual == expected @fails_service_backend() @@ -1730,17 +1720,17 @@ def test_skat_max_iteration_fails_explodes_in_37_steps(self): [10, 5, 1] ])[mt.row_idx] ) - ht = hl.skat( - hl.literal(0), - mt.row_idx, - y=mt.y, - x=mt.x[mt.col_idx], - logistic=(37, 1e-10), - # The logistic settings are only used when fitting the null model, so we need to use a - # covariate that triggers nonconvergence - covariates=[mt.y] - ) try: + ht = hl.skat( + hl.literal(0), + mt.row_idx, + y=mt.y, + x=mt.x[mt.col_idx], + logistic=(37, 1e-10), + # The logistic settings are only used when fitting the null model, so we need to use a + # covariate that triggers nonconvergence + covariates=[mt.y] + ) ht.collect()[0] except FatalError as err: assert 'Failed to fit logistic regression null model (MLE with covariates only): exploded at Newton iteration 37' in err.args[0] @@ -1757,17 +1747,17 @@ def test_skat_max_iterations_fails_to_converge_in_fewer_than_36_steps(self): [10, 5, 1] ])[mt.row_idx] ) - ht = hl.skat( - hl.literal(0), - mt.row_idx, - y=mt.y, - x=mt.x[mt.col_idx], - logistic=(36, 1e-10), - # The logistic settings are only used when fitting the null model, so we need to use a - # covariate that triggers nonconvergence - covariates=[mt.y] - ) try: + ht = hl.skat( + hl.literal(0), + mt.row_idx, + y=mt.y, + x=mt.x[mt.col_idx], + logistic=(36, 1e-10), + # The logistic settings are only used when fitting the null model, so we need to use a + # covariate that triggers nonconvergence + covariates=[mt.y] + ) ht.collect()[0] except FatalError as err: assert 'Failed to fit logistic regression null model (MLE with covariates only): Newton iteration failed to converge' in err.args[0] diff --git a/hail/python/test/hail/table/test_grouped_table.py b/hail/python/test/hail/table/test_grouped_table.py index 74c0c99bffb..918e2afeba3 100644 --- a/hail/python/test/hail/table/test_grouped_table.py +++ b/hail/python/test/hail/table/test_grouped_table.py @@ -1,10 +1,6 @@ import unittest import hail as hl -from ..helpers import * - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext class GroupedTableTests(unittest.TestCase): diff --git a/hail/python/test/hail/table/test_table.py b/hail/python/test/hail/table/test_table.py index f91b3d8a207..1538149f146 100644 --- a/hail/python/test/hail/table/test_table.py +++ b/hail/python/test/hail/table/test_table.py @@ -10,14 +10,12 @@ import hail.expr.aggregators as agg from hail.utils import new_temp_file from hail.utils.java import Env +import hail.ir as ir from hail import ExpressionException from ..helpers import * from test.hail.matrixtable.test_file_formats import create_all_values_datasets -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_annotate(self): @@ -1234,8 +1232,6 @@ def test_join_types(self): assert inner.idx.collect() == [*([2] * 4), *([3] * 9)] assert outer.idx.collect() == [1, *([2] * 4), *([3] * 9), *([4] * 4)] - @fails_service_backend() - @fails_local_backend() def test_partitioning_rewrite(self): ht = hl.utils.range_table(10, 3) ht1 = ht.annotate(x=hl.rand_unif(0, 1)) @@ -2020,6 +2016,246 @@ def test_indexed_read_boundaries(branching_factor): assert t1.idx.collect() == [141, 142, 143, 144, 152] +def test_table_randomness(): + def assert_unique_uids(ht): + ht = ht.annotate(r=hl.rand_int64()) + x = ht.aggregate(hl.struct(r=hl.agg.collect_as_set(ht.r), n=hl.agg.count())) + assert(len(x.r) == x.n) + + def assert_contains_node(t, node): + assert(t._tir.base_search(lambda x: isinstance(x, node))) + + # test TableRange + t = hl.utils.range_table(10, 3) + assert_contains_node(t, ir.TableRange) + assert_unique_uids(t) + + # test MatrixRowsTable + mt = hl.utils.range_matrix_table(10, 10, 3) + t = mt.rows() + assert_contains_node(t, ir.MatrixRowsTable) + assert_unique_uids(t) + + # test TableJoin + t1 = hl.utils.range_table(12, 3) + t1 = t1.key_by(k=(t1.idx // 2) * 2) + t2 = hl.utils.range_table(8, 3) + t2 = t2.key_by(k=(t2.idx // 2) * 3) + t = t1.join(t2, how='outer') + assert_contains_node(t, ir.TableJoin) + assert_unique_uids(t) + + # test TableLeftJoinRightDistinct + t1 = hl.utils.range_table(12, 3) + t1 = t1.key_by(k=(t1.idx // 2) * 2) + t2 = hl.utils.range_table(4, 3) + t2 = t2.key_by(k=t2.idx * 3) + t = t1.annotate(x=t2[t1.k].idx) + assert_contains_node(t, ir.TableLeftJoinRightDistinct) + assert_unique_uids(t) + + # test TableIntervalJoin + t1 = hl.utils.range_table(12, 3) + t2 = hl.utils.range_table(4, 3) + t2 = t2.key_by(k=hl.interval(t2.idx * 3, (t2.idx + 1) * 3)) + t = t1.annotate(x=t2[t1.idx].idx) + assert_contains_node(t, ir.TableIntervalJoin) + assert_unique_uids(t) + + # test TableUnion + t1 = hl.utils.range_table(12, 3) + t2 = hl.utils.range_table(4, 3) + t2 = t2.key_by(idx=t2.idx * 3) + t = t1.union(t2) + assert_contains_node(t, ir.TableUnion) + assert_unique_uids(t) + + # test TableMapGlobals + rt = hl.utils.range_table(5) + # with body randomness + t1 = rt.annotate_globals(x=hl.rand_int64()) + assert_contains_node(t1, ir.TableMapGlobals) + t1._force_count() # test with no consumer randomness + assert_unique_uids(t1) + # w/o body randomness + t2 = rt.annotate_globals(x=1) + assert_contains_node(t2, ir.TableMapGlobals) + assert_unique_uids(t2) + + # test TableExplode + t = hl.utils.range_table(5) + t = t.annotate(s=hl.struct(a=hl.range(t.idx))) + t = t.explode(t.s.a) + assert_contains_node(t, ir.TableExplode) + assert_unique_uids(t) + + # test TableKeyBy + t = hl.utils.range_table(12, 3) + t = t.key_by(k=t.idx // 4) + assert_contains_node(t, ir.TableKeyBy) + assert_unique_uids(t) + + # test TableMapRows + rt = hl.utils.range_table(12, 3) + # with body randomness + t = rt.annotate(x=hl.rand_int64()) + assert_contains_node(t, ir.TableMapRows) + t._force_count() # test with no consumer randomness + assert_unique_uids(t) + # with body scan randomness + t = rt.annotate(x=hl.scan.sum(hl.rand_int64())) + assert_contains_node(t, ir.TableMapRows) + assert_unique_uids(t) + # w/o body randomness + t = rt.annotate(x=1) + assert_contains_node(t, ir.TableMapRows) + assert_unique_uids(t) + + # test TableMapPartitions + rt = hl.utils.range_table(10, 3) + t = rt.annotate(x=hl.rand_int64()) + t = t._map_partitions(lambda part: part.map(lambda row: row.annotate(x=row.x / 2))) + assert_contains_node(t, ir.TableMapPartitions) + t._force_count() # test with no consumer randomness + + # test TableRead + rt = hl.utils.range_table(10, 3) + path = new_temp_file() + rt.write(path) + t = hl.read_table(path) + assert_contains_node(t, ir.TableRead) + assert_unique_uids(t) + + # test MatrixEntriesTable + mt = hl.utils.range_matrix_table(10, 10, 3) + t = mt.entries() + assert_contains_node(t, ir.MatrixEntriesTable) + assert_unique_uids(t) + + # test TableFilter + rt = hl.utils.range_table(20, 3) + # with cond randomness + t = rt.filter(hl.rand_int64() % 2 == 0) + assert_contains_node(t, ir.TableFilter) + t._force_count() # test with no consumer randomness + assert_unique_uids(t) + # w/o cond randomness + t = rt.filter(rt.idx < 100) + assert_contains_node(t, ir.TableFilter) + assert_unique_uids(t) + + # test TableKeyByAndAggregate + rt = hl.utils.range_table(20, 3) + # with body randomness + t = rt.group_by(k=rt.idx % 5).aggregate(x=hl.agg.sum(rt.idx) + hl.rand_int64()) + assert_contains_node(t, ir.TableKeyByAndAggregate) + t._force_count() # test with no consumer randomness + assert_unique_uids(t) + # with agg randomness + t = rt.group_by(k=rt.idx % 5).aggregate(x=hl.agg.sum(hl.rand_int64())) + assert_contains_node(t, ir.TableKeyByAndAggregate) + t._force_count() # test with no consumer randomness + assert_unique_uids(t) + # w/o body randomness + t = rt.group_by(k=rt.idx % 5).aggregate(x=hl.agg.sum(rt.idx)) + assert_contains_node(t, ir.TableKeyByAndAggregate) + assert_unique_uids(t) + + # test TableAggregateByKey + rt = hl.utils.range_table(20, 3) + t = rt.key_by(k=rt.idx % 5) + t = t.collect_by_key() + assert_contains_node(t, ir.TableAggregateByKey) + assert_unique_uids(t) + + # test MatrixColsTable + mt = hl.utils.range_matrix_table(10, 10, 3) + t = mt.cols() + assert_contains_node(t, ir.MatrixColsTable) + assert_unique_uids(t) + + # test TableParallelize + rt = hl.utils.range_table(20, 3) + # with body randomness + t = hl.Table.parallelize(hl.array([1, 2, 3]).map(lambda x: hl.struct(x=x, r=hl.rand_int64()))) + assert_contains_node(t, ir.TableParallelize) + t._force_count() # test with no consumer randomness + assert_unique_uids(t) + # w/o body randomness + t = hl.Table.parallelize(hl.array([1, 2, 3]).map(lambda x: hl.struct(x=x))) + assert_contains_node(t, ir.TableParallelize) + assert_unique_uids(t) + + # test TableHead + t = hl.utils.range_table(20, 3) + t = t.head(10) + assert_contains_node(t, ir.TableHead) + assert_unique_uids(t) + + # test TableTail + t = hl.utils.range_table(20, 3) + t = t.tail(10) + assert_contains_node(t, ir.TableTail) + assert_unique_uids(t) + + # test TableOrderBy + t = hl.utils.range_table(10, 3) + t = t.order_by(-t.idx) + assert_contains_node(t, ir.TableOrderBy) + assert_unique_uids(t) + + # test TableDistinct + rt = hl.utils.range_table(20, 3) + t = rt.key_by(k=rt.idx % 5) + t = t.distinct() + assert_contains_node(t, ir.TableDistinct) + assert_unique_uids(t) + + # test TableRepartition + if not hl.current_backend().requires_lowering: + rt = hl.utils.range_table(20, 3) + t = rt.repartition(5) + print(t._tir) + assert_contains_node(t, ir.TableRepartition) + assert_unique_uids(t) + + # test CastMatrixToTable + mt = hl.utils.range_matrix_table(10, 10, 3) + t = mt._localize_entries("entries", "cols") + assert_contains_node(t, ir.CastMatrixToTable) + assert_unique_uids(t) + + # test TableRename + rt = hl.utils.range_table(20, 3) + t = rt.rename({'idx': 'index'}) + assert_contains_node(t, ir.TableRename) + assert_unique_uids(t) + + # test TableMultiWayZipJoin + t1 = hl.utils.range_table(12, 3) + t1 = t1.key_by(k=(t1.idx // 2) * 2) + t2 = hl.utils.range_table(12, 3) + t2 = t2.key_by(k=(t2.idx // 3) * 3) + t3 = hl.utils.range_table(12, 3) + t3 = t3.key_by(k=(t3.idx // 4) * 4) + t = hl.Table.multi_way_zip_join([t1, t2, t3], 'data', 'globals') + assert_contains_node(t, ir.TableMultiWayZipJoin) + assert_unique_uids(t) + + # test TableFilterIntervals + rt = hl.utils.range_table(20, 3) + intervals = [hl.interval(0, 5), hl.interval(10, 15)] + t = hl.filter_intervals(rt, intervals) + assert_contains_node(t, ir.TableFilterIntervals) + assert_unique_uids(t) + + # test BlockMatrixToTable + bm = hl.linalg.BlockMatrix.fill(10, 10, 0) + t = bm.entries() + assert_contains_node(t, ir.BlockMatrixToTable) + assert_unique_uids(t) + + def test_query_table(): f = new_temp_file(extension='ht') ht = hl.utils.range_table(200, 10) diff --git a/hail/python/test/hail/test_context.py b/hail/python/test/hail/test_context.py index b2645ae9405..f71f6ba19d5 100644 --- a/hail/python/test/hail/test_context.py +++ b/hail/python/test/hail/test_context.py @@ -1,12 +1,8 @@ import unittest import hail as hl -from .helpers import startTestHailContext, stopTestHailContext, skip_unless_spark_backend, fails_local_backend, fails_service_backend from hail.utils.java import Env -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class Tests(unittest.TestCase): def test_init_hail_context_twice(self): diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 86091448a07..dcd378fc5ff 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -8,9 +8,6 @@ from hail.utils import new_temp_file from .helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - class ValueIRTests(unittest.TestCase): def value_irs_env(self): @@ -44,7 +41,7 @@ def value_irs(self): s = ir.Ref('s', env['s']) t = ir.Ref('t', env['t']) call = ir.Ref('call', env['call']) - rngState = ir.RNGStateLiteral((1, 2, 3, 4)) + rngState = ir.RNGStateLiteral() table = ir.TableRange(5, 3) @@ -75,7 +72,7 @@ def aggregate(x): ir.ArraySort(ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32))), ir.ToSet(a), ir.ToDict(da), - ir.ToArray(a), + ir.ToArray(st), ir.CastToArray(ir.NA(hl.tset(hl.tint32))), ir.MakeNDArray(ir.MakeArray([ir.F64(-1.0), ir.F64(1.0)], hl.tarray(hl.tfloat64)), ir.MakeTuple([ir.I64(1), ir.I64(2)]), diff --git a/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py b/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py index d275ff3751a..8355ee113a6 100644 --- a/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py +++ b/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py @@ -7,11 +7,6 @@ from hail.utils import hadoop_open, hadoop_copy from hailtop.utils import secret_alnum_string from hail.utils.java import FatalError -from ..helpers import startTestHailContext, stopTestHailContext, fails_service_backend - - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext def touch(fs, filename: str): diff --git a/hail/python/test/hail/utils/test_placement_tree.py b/hail/python/test/hail/utils/test_placement_tree.py index 1e52a2cf312..55c70bd68b9 100644 --- a/hail/python/test/hail/utils/test_placement_tree.py +++ b/hail/python/test/hail/utils/test_placement_tree.py @@ -3,10 +3,6 @@ import hail as hl from hail.utils.placement_tree import PlacementTree -from ..helpers import startTestHailContext, stopTestHailContext - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext class Tests(unittest.TestCase): diff --git a/hail/python/test/hail/utils/test_utils.py b/hail/python/test/hail/utils/test_utils.py index fc8e1bda7a8..e759d4901dd 100644 --- a/hail/python/test/hail/utils/test_utils.py +++ b/hail/python/test/hail/utils/test_utils.py @@ -9,9 +9,6 @@ from ..helpers import * -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - def normalize_path(path: str) -> str: return hl.hadoop_stat(path)['path'] @@ -285,15 +282,6 @@ def test_interval_ops(self): def test_range_matrix_table_n_lt_partitions(self): hl.utils.range_matrix_table(1, 1)._force_count_rows() - def test_seeding_is_consistent(self): - hl.set_global_seed(0) - a = [Env.next_seed() for _ in range(10)] - hl.set_global_seed(0) - b = [Env.next_seed() for _ in range(10)] - - self.assertEqual(len(set(a)), 10) - self.assertEqual(a, b) - def test_escape_string(self): self.assertEqual(escape_str("\""), "\\\"") self.assertEqual(escape_str("cat"), "cat") @@ -345,7 +333,7 @@ def test_json_encoder(self): @pytest.fixture(scope="module") -def glob_tests_directory(): +def glob_tests_directory(init_hail): with hl.TemporaryDirectory() as dirname: touch(dirname + '/abc/ghi/123') touch(dirname + '/abc/ghi/!23') @@ -371,16 +359,12 @@ def glob_tests_directory(): def test_hadoop_ls_folder_glob(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/123', glob_tests_directory + '/abc/jkl/123'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/*/123')] assert set(actual) == set(expected) def test_hadoop_ls_prefix_folder_glob_qmarks(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/78', glob_tests_directory + '/abc/jkl/78'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/*/??')] @@ -388,8 +372,6 @@ def test_hadoop_ls_prefix_folder_glob_qmarks(glob_tests_directory): def test_hadoop_ls_two_folder_globs(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/123', glob_tests_directory + '/abc/jkl/123', glob_tests_directory + '/def/ghi/123', @@ -399,8 +381,6 @@ def test_hadoop_ls_two_folder_globs(glob_tests_directory): def test_hadoop_ls_two_folder_globs_and_two_qmarks(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/78', glob_tests_directory + '/abc/jkl/78', glob_tests_directory + '/def/ghi/78', @@ -410,8 +390,6 @@ def test_hadoop_ls_two_folder_globs_and_two_qmarks(glob_tests_directory): def test_hadoop_ls_one_folder_glob_and_qmarks_in_multiple_components(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/78', glob_tests_directory + '/def/ghi/78'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/*/?h?/??')] @@ -419,24 +397,18 @@ def test_hadoop_ls_one_folder_glob_and_qmarks_in_multiple_components(glob_tests_ def test_hadoop_ls_groups(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/123'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/[ghi][ghi]i/123')] assert set(actual) == set(expected) def test_hadoop_ls_size_one_groups(glob_tests_directory): - fs = hl.current_backend().fs - expected = [] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/[h][g]i/123')] assert set(actual) == set(expected) def test_hadoop_ls_component_with_only_groups(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/123', glob_tests_directory + '/abc/ghi/!23', glob_tests_directory + '/abc/ghi/?23', @@ -447,8 +419,6 @@ def test_hadoop_ls_component_with_only_groups(glob_tests_directory): def test_hadoop_ls_negated_group(glob_tests_directory): - fs = hl.current_backend().fs - expected = [glob_tests_directory + '/abc/ghi/!23', glob_tests_directory + '/abc/ghi/?23'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/ghi/[!1]23')] diff --git a/hail/python/test/hail/vds/test_combiner.py b/hail/python/test/hail/vds/test_combiner.py index 46ab5b2e6b4..e8fdcb780a1 100644 --- a/hail/python/test/hail/vds/test_combiner.py +++ b/hail/python/test/hail/vds/test_combiner.py @@ -6,10 +6,7 @@ from hail.utils.misc import new_temp_file from hail.vds.combiner import combine_variant_datasets, new_combiner, load_combiner, transform_gvcf from hail.vds.combiner.combine import defined_entry_fields -from ..helpers import startTestHailContext, stopTestHailContext, resource, fails_local_backend, fails_service_backend - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ..helpers import resource, fails_local_backend, fails_service_backend all_samples = ['HG00308', 'HG00592', 'HG02230', 'NA18534', 'NA20760', diff --git a/hail/python/test/hail/vds/test_vds.py b/hail/python/test/hail/vds/test_vds.py index 89b0a746b4e..2acbd9cc486 100644 --- a/hail/python/test/hail/vds/test_vds.py +++ b/hail/python/test/hail/vds/test_vds.py @@ -4,10 +4,7 @@ import hail as hl from hail.utils import new_temp_file from hail.vds.combiner.combine import defined_entry_fields -from ..helpers import startTestHailContext, stopTestHailContext, resource, fails_local_backend, fails_service_backend - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext +from ..helpers import resource, fails_local_backend, fails_service_backend # run this method to regenerate the combined VDS from 5 samples diff --git a/hail/python/test/hail/vds/test_vds_functions.py b/hail/python/test/hail/vds/test_vds_functions.py index b305d703b37..91706fdb82d 100644 --- a/hail/python/test/hail/vds/test_vds_functions.py +++ b/hail/python/test/hail/vds/test_vds_functions.py @@ -1,10 +1,5 @@ import hail as hl -from ..helpers import startTestHailContext, stopTestHailContext - -setUpModule = startTestHailContext -tearDownModule = stopTestHailContext - def test_lgt_to_gt(): call_0_0_f = hl.call(0, 0, phased=False) call_0_0_t = hl.call(0, 0, phased=True) diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala index 711b7102258..ee3898150e0 100644 --- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala +++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala @@ -29,7 +29,8 @@ object HailFeatureFlags { ("use_ssa_logs", "HAIL_USE_SSA_LOGS" -> null), ("gcs_requester_pays_project", "HAIL_GCS_REQUESTER_PAYS_PROJECT" -> null), ("gcs_requester_pays_buckets", "HAIL_GCS_REQUESTER_PAYS_BUCKETS" -> null), - ("index_branching_factor", "HAIL_INDEX_BRANCHING_FACTOR" -> null) + ("index_branching_factor", "HAIL_INDEX_BRANCHING_FACTOR" -> null), + ("rng_nonce", "HAIL_RNG_NONCE" -> "0x0") ) def fromEnv(): HailFeatureFlags = diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index afa736ad215..8e14ac15e85 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -107,7 +107,7 @@ class ExecuteContext( ) extends Closeable { var backendContext: BackendContext = _ - val rngKey: IndexedSeq[Long] = Threefry.expandKey(Array(0L, 0L, 0L, 0L)) + val rngNonce: Long = java.lang.Long.decode(getFlag("rng_nonce")) private val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 4fef27825d8..623c7e6d380 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -11,7 +11,7 @@ import is.hail.asm4s._ import is.hail.backend.{Backend, BackendContext, BroadcastValue, ExecuteContext, HailTaskContext} import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.lowering._ -import is.hail.expr.ir.{Compile, IR, IRParser, MakeTuple, SortField} +import is.hail.expr.ir.{Compile, IR, IRParser, MakeTuple, SortField, Threefry} import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.bgen.IndexBgen @@ -455,7 +455,7 @@ class ServiceBackendSocketAPI2( private[this] val backend: ServiceBackend, private[this] val in: InputStream, private[this] val out: OutputStream, - private[this] val sessionId: String + private[this] val sessionId: String, ) extends Thread { private[this] val LOAD_REFERENCES_FROM_DATASET = 1 private[this] val VALUE_TYPE = 2 diff --git a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala index e5ecff4b9dd..a5621fcfd9e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala @@ -945,7 +945,7 @@ case class ValueToBlockMatrix( } case class BlockMatrixRandom( - seed: Long, + staticUID: Long, gaussian: Boolean, shape: IndexedSeq[Long], blockSize: Int) extends BlockMatrixIR { @@ -961,11 +961,11 @@ case class BlockMatrixRandom( def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixRandom = { assert(newChildren.isEmpty) - BlockMatrixRandom(seed, gaussian, shape, blockSize) + BlockMatrixRandom(staticUID, gaussian, shape, blockSize) } override protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = { - BlockMatrix.random(shape(0), shape(1), blockSize, seed, gaussian) + BlockMatrix.random(shape(0), shape(1), blockSize, ctx.rngNonce, staticUID, gaussian) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Children.scala b/hail/src/main/scala/is/hail/expr/ir/Children.scala index f1327834966..c11788b7fd1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Children.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Children.scala @@ -95,7 +95,7 @@ object Children { Array(orderedCollection, elem) case GroupByKey(collection) => Array(collection) - case RNGStateLiteral(_) => none + case RNGStateLiteral() => none case RNGSplit(state, split) => Array(state, split) case StreamLen(a) => @@ -215,7 +215,7 @@ object Children { args.toFastIndexedSeq case Apply(_, _, args, _, _) => args.toFastIndexedSeq - case ApplySeeded(_, args, rngState, seed, _) => + case ApplySeeded(_, args, rngState, _, _) => args.toFastIndexedSeq :+ rngState case ApplySpecial(_, _, args, _, _) => args.toFastIndexedSeq diff --git a/hail/src/main/scala/is/hail/expr/ir/Copy.scala b/hail/src/main/scala/is/hail/expr/ir/Copy.scala index 46d307e5c87..9ea593dc625 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Copy.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Copy.scala @@ -165,7 +165,7 @@ object Copy { case GroupByKey(_) => assert(newChildren.length == 1) GroupByKey(newChildren(0).asInstanceOf[IR]) - case RNGStateLiteral(key) => RNGStateLiteral(key) + case RNGStateLiteral() => RNGStateLiteral() case RNGSplit(_, _) => assert(newChildren.nonEmpty) RNGSplit(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) @@ -329,8 +329,8 @@ object Copy { r case Apply(fn, typeArgs, args, t, errorID) => Apply(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), t, errorID) - case ApplySeeded(fn, args, rngState, seed, t) => - ApplySeeded(fn, newChildren.init.map(_.asInstanceOf[IR]), newChildren.last.asInstanceOf[IR], seed, t) + case ApplySeeded(fn, args, rngState, staticUID, t) => + ApplySeeded(fn, newChildren.init.map(_.asInstanceOf[IR]), newChildren.last.asInstanceOf[IR], staticUID, t) case ApplySpecial(fn, typeArgs, args, t, errorID) => ApplySpecial(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), t, errorID) // from MatrixIR diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index 2b3acfc8625..4944cdab59d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -1306,29 +1306,26 @@ class Emit[C]( dictType.construct(finishOuter(cb)) } - case RNGStateLiteral(key) => - IEmitCode.present(cb, SRNGStateStaticSizeValue(cb, key)) + case RNGStateLiteral() => + IEmitCode.present(cb, SRNGStateStaticSizeValue(cb)) case RNGSplit(state, dynBitstring) => - // FIXME: When new rng support is complete, don't allow missing states - emitI(state).flatMap(cb) { stateValue => - emitI(dynBitstring).map(cb) { tupleOrLong => - val longs = if (tupleOrLong.isInstanceOf[SInt64Value]) { - Array(tupleOrLong.asInt64.value) - } else { - val tuple = tupleOrLong.asBaseStruct - Array.tabulate(tuple.st.size) { i => - tuple.loadField(cb, i) - .get(cb, "RNGSplit tuple components are required") - .asInt64 - .value - } - } - var result = stateValue.asRNGState - longs.foreach(l => result = result.splitDyn(cb, l)) - result + val stateValue = emitI(state).get(cb) + val tupleOrLong = emitI(dynBitstring).get(cb) + val longs = if (tupleOrLong.isInstanceOf[SInt64Value]) { + Array(tupleOrLong.asInt64.value) + } else { + val tuple = tupleOrLong.asBaseStruct + Array.tabulate(tuple.st.size) { i => + tuple.loadField(cb, i) + .get(cb, "RNGSplit tuple components are required") + .asInt64 + .value } } + var result = stateValue.asRNGState + longs.foreach(l => result = result.splitDyn(cb, l)) + presentPC(result) case x@StreamLen(a) => emitStream(a, cb, region).map(cb) { case stream: SStreamValue => @@ -2066,24 +2063,14 @@ class Emit[C]( val rvAgg = agg.Extract.getAgg(sig) rvAgg.result(cb, sc.states(idx), region) - case x@ApplySeeded(fn, args, rngState, seed, rt) => + case x@ApplySeeded(fn, args, rngState, staticUID, rt) => val codeArgs = args.map(a => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, a))) + val codeArgsMem = codeArgs.map(_.memoize(cb, "ApplySeeded_arg")) + val state = emitI(rngState).get(cb) val impl = x.implementation - val unified = impl.unify(Array.empty[Type], x.argTypes, rt) - assert(unified) - val newRNGEnabled = Array[String]() - if (newRNGEnabled.contains(fn)) { - val pureImpl = x.pureImplementation - assert(pureImpl.unify(Array.empty[Type], rngState.typ +: x.argTypes, rt)) - val codeArgsMem = codeArgs.map(_.memoize(cb, "ApplySeeded_arg")) - emitI(rngState).consumeI(cb, { - impl.applySeededI(seed, cb, region, impl.computeReturnEmitType(x.typ, codeArgs.map(_.emitType)).st, codeArgsMem.map(_.load): _*) - }, { state => - pureImpl.applyI(EmitRegion(cb.emb, region), cb, impl.computeReturnEmitType(x.typ, codeArgs.map(_.emitType)).st, Seq[Type](), const(0), EmitCode.present(mb, state) +: codeArgsMem.map(_.load): _*) - }) - } else { - impl.applySeededI(seed, cb, region, impl.computeReturnEmitType(x.typ, codeArgs.map(_.emitType)).st, codeArgs: _*) - } + assert(impl.unify(Array.empty[Type], x.argTypes, rt)) + val newState = EmitCode.present(mb, state.asRNGState.splitStatic(cb, staticUID)) + impl.applyI(EmitRegion(cb.emb, region), cb, impl.computeReturnEmitType(x.typ, newState.emitType +: codeArgs.map(_.emitType)).st, Seq[Type](), const(0), newState +: codeArgsMem.map(_.load): _*) case AggStateValue(i, _) => val AggContainer(_, sc, _) = container.get diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala index 82d19f6cb9d..cebfb754893 100644 --- a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala @@ -171,6 +171,8 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def newRNG(seed: Long): Value[IRRandomness] = ecb.newRNG(seed) + def getThreefryRNG(): Value[ThreefryRandomEngine] = ecb.getThreefryRNG() + def resultWithIndex(writeIRs: Boolean = false, print: Option[PrintWriter] = None): (HailClassLoader, FS, Int, Region) => C = ecb.resultWithIndex(writeIRs, print) def getOrGenEmitMethod( @@ -600,6 +602,8 @@ class EmitClassBuilder[C]( val rngs: BoxedArrayBuilder[(Settable[IRRandomness], Code[IRRandomness])] = new BoxedArrayBuilder() + var threefryRNG: Option[(Settable[ThreefryRandomEngine], Code[ThreefryRandomEngine])] = None + def makeAddPartitionRegion(): Unit = { cb.addInterface(typeInfo[FunctionWithPartitionRegion].iname) val mb = newEmitMethod("addPartitionRegion", FastIndexedSeq[ParamType](typeInfo[Region]), typeInfo[Unit]) @@ -645,9 +649,15 @@ class EmitClassBuilder[C]( val mb = newEmitMethod("setPartitionIndex", IndexedSeq[ParamType](typeInfo[Int]), typeInfo[Unit]) val rngFields = rngs.result() - val initialize = Code(rngFields.map { case (field, initialization) => - field := initialization - }) + val initialize = Code( + Code(rngFields.map { case (field, initialization) => + field := initialization + }), + threefryRNG match { + case Some((field, init)) => field := init + case None => Code._empty.get + } + ) val reseed = Code(rngFields.map { case (field, _) => field.invoke[Int, Unit]("reset", mb.getCodeParam[Int](1)) @@ -666,6 +676,18 @@ class EmitClassBuilder[C]( rng } + def getThreefryRNG(): Value[ThreefryRandomEngine] = { + threefryRNG match { + case Some((rngField, _)) => rngField + case None => + val rngField = genFieldThisRef[ThreefryRandomEngine]() + val rngInit = Code.invokeScalaObject0[ThreefryRandomEngine]( + ThreefryRandomEngine.getClass, "apply") + threefryRNG = Some(rngField -> rngInit) + rngField + } + } + def resultWithIndex( writeIRs: Boolean, print: Option[PrintWriter] = None diff --git a/hail/src/main/scala/is/hail/expr/ir/IR.scala b/hail/src/main/scala/is/hail/expr/ir/IR.scala index cc93419945b..58e6bddb8da 100644 --- a/hail/src/main/scala/is/hail/expr/ir/IR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/IR.scala @@ -295,12 +295,7 @@ final case class LowerBoundOnOrderedCollection(orderedCollection: IR, elem: IR, final case class GroupByKey(collection: IR) extends IR -// FIXME: Revisit all uses after all infra is in place -object RNGStateLiteral { - def apply(): RNGStateLiteral = - RNGStateLiteral(Array.fill(4)(util.Random.nextLong())) -} -final case class RNGStateLiteral(key: IndexedSeq[Long]) extends IR +final case class RNGStateLiteral() extends IR final case class RNGSplit(state: IR, dynBitstring: IR) extends IR @@ -700,11 +695,9 @@ sealed abstract class AbstractApplyNode[F <: JVMFunction] extends IR { final case class Apply(function: String, typeArgs: Seq[Type], args: Seq[IR], returnType: Type, errorID: Int) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] -final case class ApplySeeded(function: String, args: Seq[IR], rngState: IR, seed: Long, returnType: Type) extends AbstractApplyNode[SeededJVMFunction] { +final case class ApplySeeded(function: String, _args: Seq[IR], rngState: IR, staticUID: Long, returnType: Type) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] { + val args = rngState +: _args val typeArgs: Seq[Type] = Seq.empty[Type] - lazy val pureImplementation: UnseededMissingnessObliviousJVMFunction = - IRFunctionRegistry.lookupFunctionOrFail(function + "_pure", returnType, typeArgs, TRNGState +: argTypes) - .asInstanceOf[UnseededMissingnessObliviousJVMFunction] } final case class ApplySpecial(function: String, typeArgs: Seq[Type], args: Seq[IR], returnType: Type, errorID: Int) extends AbstractApplyNode[UnseededMissingnessAwareJVMFunction] diff --git a/hail/src/main/scala/is/hail/expr/ir/InferType.scala b/hail/src/main/scala/is/hail/expr/ir/InferType.scala index 52a95988eaf..5f31258a1b4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/InferType.scala +++ b/hail/src/main/scala/is/hail/expr/ir/InferType.scala @@ -111,7 +111,7 @@ object InferType { case ToStream(a, _) => val elt = tcoerce[TIterable](a.typ).elementType TStream(elt) - case RNGStateLiteral(_) => + case RNGStateLiteral() => TRNGState case RNGSplit(_, _) => TRNGState diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala index e4672387df0..23bf243851b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala @@ -673,10 +673,11 @@ object LowerMatrixIR { val keyMap = Symbol(genUID()) val aggElementIdx = Symbol(genUID()) - val substEnv = matrixSubstEnv(child) - val ceSub = subst(lower(ctx, colExpr, ab), substEnv) - val vaBinding = 'row.selectFields(child.typ.rowType.fieldNames: _*) - val eeSub = subst(lower(ctx, entryExpr, ab), substEnv.bindEval("va", vaBinding).bindAgg("va", vaBinding)) + val e1 = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*), + "va" -> 'row.selectFields(child.typ.rowType.fieldNames: _*)) + val e2 = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*)) + val ceSub = subst(lower(ctx, colExpr, ab), BindingEnv(e2, agg = Some(e1))) + val eeSub = subst(lower(ctx, entryExpr, ab), BindingEnv(e1, agg = Some(e1))) lower(ctx, child, ab) .mapGlobals('global.insertFields(keyMap -> diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala index 0696e835a1e..0e18316180f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala @@ -937,7 +937,7 @@ object IRParser { case Array(a, start, stop, step) => ArraySlice(a, start, Some(stop), step, errorID) } case "RNGStateLiteral" => - done(RNGStateLiteral(int64_literals(it))) + done(RNGStateLiteral()) case "RNGSplit" => for { state <- ir_value_expr(env)(it) @@ -1363,12 +1363,12 @@ object IRParser { } yield ConsoleLog(msg, result) case "ApplySeeded" => val function = identifier(it) - val seed = int64_literal(it) + val staticUID = int64_literal(it) val rt = type_expr(env.typEnv)(it) for { rngState <- ir_value_expr(env)(it) args <- ir_value_children(env)(it) - } yield ApplySeeded(function, args, rngState, seed, rt) + } yield ApplySeeded(function, args, rngState, staticUID, rt) case "ApplyIR" | "ApplySpecial" | "Apply" => val errorID = int32_literal(it) val function = identifier(it) @@ -2054,11 +2054,11 @@ object IRParser { ValueToBlockMatrix(child, shape, blockSize) } case "BlockMatrixRandom" => - val seed = int64_literal(it) + val staticUID = int64_literal(it) val gaussian = boolean_literal(it) val shape = int64_literals(it) val blockSize = int32_literal(it) - done(BlockMatrixRandom(seed, gaussian, shape, blockSize)) + done(BlockMatrixRandom(staticUID, gaussian, shape, blockSize)) case "RelationalLetBlockMatrix" => val name = identifier(it) for { diff --git a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala index a7c0b0cca16..9f83e6e190c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala @@ -196,7 +196,6 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case ApplyBinaryPrimOp(op, _, _) => single(Pretty.prettyClass(op)) case ApplyUnaryPrimOp(op, _) => single(Pretty.prettyClass(op)) case ApplyComparisonOp(op, _, _) => single(op.render()) - case RNGStateLiteral(key) => single(prettyLongs(key, false)) case GetField(_, name) => single(prettyIdentifier(name)) case GetTupleElement(_, idx) => single(idx.toString) case MakeTuple(fields) => FastSeq(prettyInts(fields.map(_._1).toFastIndexedSeq, elideLiterals)) @@ -271,7 +270,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case ArrayRef(_,_, errorID) => single(s"$errorID") case ApplyIR(function, typeArgs, _, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), ir.typ.parsableString()) case Apply(function, typeArgs, _, t, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) - case ApplySeeded(function, _, rngState, seed, t) => FastSeq(prettyIdentifier(function), seed.toString, t.parsableString()) + case ApplySeeded(function, _, rngState, staticUID, t) => FastSeq(prettyIdentifier(function), staticUID.toString, t.parsableString()) case ApplySpecial(function, typeArgs, _, t, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) case SelectFields(_, fields) => single(fillList(fields.view.map(f => text(prettyIdentifier(f))))) case LowerBoundOnOrderedCollection(_, _, onKey) => single(Pretty.prettyBooleanLiteral(onKey)) @@ -307,8 +306,9 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, single(fillList(indicesToKeepPerDim.toSeq.view.map(indices => prettyLongs(indices, elideLiterals)))) case BlockMatrixSparsify(_, sparsifier) => single(sparsifier.pretty()) - case BlockMatrixRandom(seed, gaussian, shape, blockSize) => - FastSeq(seed.toString, + case BlockMatrixRandom(staticUID, gaussian, shape, blockSize) => + FastSeq( + staticUID.toString, Pretty.prettyBooleanLiteral(gaussian), prettyLongs(shape, elideLiterals), blockSize.toString) diff --git a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala index b2b02038532..1f447b81d4d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala +++ b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala @@ -791,7 +791,8 @@ object PruneDeadFields { case MatrixRepartition(child, _, _) => memoizeMatrixIR(ctx, child, requestedType, memo) case MatrixUnionRows(children) => - children.foreach(memoizeMatrixIR(ctx, _, requestedType, memo)) + memoizeMatrixIR(ctx, children.head, requestedType, memo) + children.tail.foreach(memoizeMatrixIR(ctx, _, requestedType.copy(colType = requestedType.colKeyStruct), memo)) case MatrixDistinctByRow(child) => val dep = requestedType.copy( rowKey = child.typ.rowKey, @@ -1752,10 +1753,12 @@ object PruneDeadFields { rebuildIR(ctx, colExpr, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.colEnv)), memo)) case MatrixUnionRows(children) => val requestedType = memo.requestedType.lookup(mir).asInstanceOf[MatrixType] - MatrixUnionRows(children.map { child => - upcast(ctx, rebuild(ctx, child, memo), requestedType, + val firstChild = upcast(ctx, rebuild(ctx, children.head, memo), requestedType, upcastGlobals = false) + val remainingChildren = children.tail.map { child => + upcast(ctx, rebuild(ctx, child, memo), requestedType.copy(colType = requestedType.colKeyStruct), upcastGlobals = false) - }) + } + MatrixUnionRows(firstChild +: remainingChildren) case MatrixUnionCols(left, right, joinType) => val requestedType = memo.requestedType.lookup(mir).asInstanceOf[MatrixType] val left2 = rebuild(ctx, left, memo) diff --git a/hail/src/main/scala/is/hail/expr/ir/Random.scala b/hail/src/main/scala/is/hail/expr/ir/Random.scala index 085d6873d70..b4ec50be458 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Random.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Random.scala @@ -1,11 +1,16 @@ package is.hail.expr.ir import is.hail.asm4s._ -import is.hail.types.physical.stypes.concrete.SRNGState import is.hail.utils.FastIndexedSeq +import net.sourceforge.jdistlib.{Beta, Gamma, HyperGeometric, Poisson} import net.sourceforge.jdistlib.rng.RandomEngine +import org.apache.commons.math3.random.RandomGenerator object Threefry { + val staticTweak = -1L + val finalBlockNoPadTweak = -2L + val finalBlockPaddedTweak = -3L + val keyConst = 0x1BD11BDAA9FC1A22L val rotConsts = Array( @@ -20,6 +25,9 @@ object Threefry { val defaultNumRounds = 20 + val defaultKey: IndexedSeq[Long] = + expandKey(FastIndexedSeq(0x215d6dfdb7dfdf6bL, 0x045cfa043329c49fL, 0x9ec75a93692444ddL, 0x1284681663220f1cL)) + def expandKey(k: IndexedSeq[Long]): IndexedSeq[Long] = { assert(k.length == 4) val k4 = k(0) ^ k(1) ^ k(2) ^ k(3) ^ keyConst @@ -36,24 +44,25 @@ object Threefry { cb.assign(x1, x0 ^ x1) } - def injectKey(key: IndexedSeq[Long], tweak: Long, block: Array[Long], s: Int): Unit = { - val tweakExt = Array[Long](tweak, 0, tweak) + def injectKey(key: IndexedSeq[Long], tweak: IndexedSeq[Long], block: Array[Long], s: Int): Unit = { + assert(tweak.length == 3) + assert(key.length == 5) + assert(block.length == 4) block(0) += key(s % 5) - block(1) += key((s + 1) % 5) + tweakExt(s % 3) - block(2) += key((s + 2) % 5) + tweakExt((s + 1) % 3) + block(1) += key((s + 1) % 5) + tweak(s % 3) + block(2) += key((s + 2) % 5) + tweak((s + 1) % 3) block(3) += key((s + 3) % 5) + s.toLong } def injectKey(cb: CodeBuilderLike, key: IndexedSeq[Long], - tweak: Value[Long], + tweak: IndexedSeq[Value[Long]], block: IndexedSeq[Settable[Long]], s: Int ): Unit = { - val tweakExt = Array[Value[Long]](tweak, const(0), tweak) cb.assign(block(0), block(0) + key(s % 5)) - cb.assign(block(1), block(1) + const(key((s + 1) % 5)) + tweakExt(s % 3)) - cb.assign(block(2), block(2) + const(key((s + 2) % 5)) + tweakExt((s + 1) % 3)) + cb.assign(block(1), block(1) + const(key((s + 1) % 5)) + tweak(s % 3)) + cb.assign(block(2), block(2) + const(key((s + 2) % 5)) + tweak((s + 1) % 3)) cb.assign(block(3), block(3) + const(key((s + 3) % 5)) + const(s.toLong)) } @@ -63,16 +72,14 @@ object Threefry { x(3) = tmp } - def encryptUnrolled(k0: Long, k1: Long, k2: Long, k3: Long, t: Long, _x0: Long, _x1: Long, _x2: Long, _x3: Long): Unit = { + def encryptUnrolled(k0: Long, k1: Long, k2: Long, k3: Long, k4: Long, t0: Long, t1: Long, x: Array[Long]): Unit = { import java.lang.Long.rotateLeft - var x0 = _x0 - var x1 = _x1 - var x2 = _x2 - var x3 = _x3 - val k4 = k0 ^ k1 ^ k2 ^ k3 ^ keyConst + var x0 = x(0); var x1 = x(1); var x2 = x(2); var x3 = x(3) + val t2 = t0 ^ t1 + // d = 0 // injectKey s = 0 - x0 += k0; x1 += k1 + t; x2 += k2; x3 += k3 + x0 += k0; x1 += k1 + t0; x2 += k2 + t1; x3 += k3 x0 += x1; x1 = rotateLeft(x1, 14); x1 ^= x0 x2 += x3; x3 = rotateLeft(x3, 16); x3 ^= x2 // d = 1 @@ -86,7 +93,7 @@ object Threefry { x2 += x1; x1 = rotateLeft(x1, 37); x1 ^= x2 // d = 4 // injectKey s = 1 - x0 += k1; x1 += k2; x2 += k3 + t; x3 += k4 + 1 + x0 += k1; x1 += k2 + t1; x2 += k3 + t2; x3 += k4 + 1 x0 += x1; x1 = rotateLeft(x1, 25); x1 ^= x0 x2 += x3; x3 = rotateLeft(x3, 33); x3 ^= x2 // d = 5 @@ -100,7 +107,7 @@ object Threefry { x2 += x1; x1 = rotateLeft(x1, 32); x1 ^= x2 // d = 8 // injectKey s = 2 - x0 += k2; x1 += k3 + t; x2 += k4 + t; x3 += k0 + 2 + x0 += k2; x1 += k3 + t2; x2 += k4 + t0; x3 += k0 + 2 x0 += x1; x1 = rotateLeft(x1, 14); x1 ^= x0 x2 += x3; x3 = rotateLeft(x3, 16); x3 ^= x2 // d = 9 @@ -114,7 +121,7 @@ object Threefry { x2 += x1; x1 = rotateLeft(x1, 37); x1 ^= x2 // d = 12 // injectKey s = 3 - x0 += k3; x1 += k4 + t; x2 += k0; x3 += k1 + 3 + x0 += k3; x1 += k4 + t0; x2 += k0 + t1; x3 += k1 + 3 x0 += x1; x1 = rotateLeft(x1, 25); x1 ^= x0 x2 += x3; x3 = rotateLeft(x3, 33); x3 ^= x2 // d = 13 @@ -128,7 +135,7 @@ object Threefry { x2 += x1; x1 = rotateLeft(x1, 32); x1 ^= x2 // d = 16 // injectKey s = 4 - x0 += k4; x1 += k0; x2 += k1 + t; x3 += k2 + 4 + x0 += k4; x1 += k0 + t1; x2 += k1 + t2; x3 += k2 + 4 x0 += x1; x1 = rotateLeft(x1, 14); x1 ^= x0 x2 += x3; x3 = rotateLeft(x3, 16); x3 ^= x2 // d = 17 @@ -142,15 +149,19 @@ object Threefry { x2 += x1; x1 = rotateLeft(x1, 37); x1 ^= x2 // d = 20 // injectKey s = 5 - x0 += k0; x1 += k1 + t; x2 += k2 + t; x3 += k3 + 5 + x0 += k0; x1 += k1 + t2; x2 += k2 + t0; x3 += k3 + 5 + + x(0) = x0; x(1) = x1; x(2) = x2; x(3) = x3 } - def encrypt(k: IndexedSeq[Long], t: Long, x: Array[Long]): Unit = + def encrypt(k: IndexedSeq[Long], t: IndexedSeq[Long], x: Array[Long]): Unit = encrypt(k, t, x, defaultNumRounds) - def encrypt(k: IndexedSeq[Long], t: Long, x: Array[Long], rounds: Int): Unit = { + def encrypt(k: IndexedSeq[Long], _t: IndexedSeq[Long], x: Array[Long], rounds: Int): Unit = { assert(k.length == 5) + assert(_t.length == 2) assert(x.length == 4) + val t = Array(_t(0), _t(1), _t(0) ^ _t(1)) for (d <- 0 until rounds) { if (d % 4 == 0) @@ -174,20 +185,22 @@ object Threefry { def encrypt(cb: CodeBuilderLike, k: IndexedSeq[Long], - t: Value[Long], + t: IndexedSeq[Value[Long]], x: IndexedSeq[Settable[Long]] ): Unit = encrypt(cb, k, t, x, defaultNumRounds) def encrypt(cb: CodeBuilderLike, k: IndexedSeq[Long], - t: Value[Long], + _t: IndexedSeq[Value[Long]], _x: IndexedSeq[Settable[Long]], rounds: Int ): Unit = { assert(k.length == 5) + assert(_t.length == 2) assert(_x.length == 4) val x = _x.toArray + val t = Array(_t(0), _t(1), cb.memoize(_t(0) ^ _t(1))) for (d <- 0 until rounds) { if (d % 4 == 0) @@ -207,11 +220,46 @@ object Threefry { cb.println(s"[$info]=\n\t", x(0).toString, " ", x(1).toString, " ", x(2).toString, " ", x(3).toString) } - def apply(k: IndexedSeq[Long]): AsmFunction2[Array[Long], Long, Unit] = { - val f = FunctionBuilder[Array[Long], Long, Unit]("Threefry") + def pmac(nonce: Long, staticID: Long, message: IndexedSeq[Long]): Array[Long] = { + val (hash, finalTweak) = pmacHash(nonce, staticID, message) + encrypt(Threefry.defaultKey, Array(finalTweak, 0L), hash) + hash + } + + def pmacHash(nonce: Long, staticID: Long, _message: IndexedSeq[Long]): (Array[Long], Long) = { + val length = _message.length + val paddedLength = Math.max((length + 3) & (~3), 4) + val padded = (paddedLength != length) + val message = Array.ofDim[Long](paddedLength) + _message.copyToArray(message) + if (padded) message(length) = 1L + + val sum = Array(nonce, staticID, 0L, 0L) + encrypt(Threefry.defaultKey, Array(Threefry.staticTweak, 0L), sum) + var i = 0 + while (i + 4 < paddedLength) { + val x = message.slice(i, i + 4) + encrypt(Threefry.defaultKey, Array(i.toLong, 0L), x) + sum(0) ^= x(0) + sum(1) ^= x(1) + sum(2) ^= x(2) + sum(3) ^= x(3) + i += 4 + } + for (j <- 0 until 4) { + sum(j) ^= message(i + j) + } + val finalTweak = if (padded) Threefry.finalBlockPaddedTweak else Threefry.finalBlockNoPadTweak + + (sum, finalTweak) + } + + def apply(k: IndexedSeq[Long]): AsmFunction2[Array[Long], Array[Long], Unit] = { + val f = FunctionBuilder[Array[Long], Array[Long], Unit]("Threefry") f.mb.emitWithBuilder { cb => val xArray = f.mb.getArg[Array[Long]](1) - val t = f.mb.getArg[Long](2) + val tArray = f.mb.getArg[Array[Long]](2) + val t = Array(cb.memoize(tArray(0)), cb.memoize(tArray(1))) val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"x$i", xArray(i))) encrypt(cb, expandKey(k), t, x) for (i <- 0 until 4) cb += (xArray(i) = x(i)) @@ -221,44 +269,44 @@ object Threefry { } } -class RNGState { - val staticAcc: Array[Long] = Array.fill(4)(0) - val staticIdx: Int = 0 - val staticOpen: Array[Long] = Array.fill(4)(0) - val staticOpenLen: Int = 0 - val dynAcc: Array[Long] = Array.fill(4)(0) - val dynIdx: Int = 0 - val dynOpen: Array[Long] = Array.fill(4)(0) - val dynOpenLen: Int = 0 -} - object ThreefryRandomEngine { - def apply( - k1: Long, k2: Long, k3: Long, k4: Long, - h1: Long, h2: Long, h3: Long, h4: Long, - x1: Long, x2: Long, x3: Long - ): ThreefryRandomEngine = { + def apply(): ThreefryRandomEngine = { + val key = Threefry.defaultKey new ThreefryRandomEngine( - Threefry.expandKey(FastIndexedSeq(k1, k2, k3, k4)), - Array(h1 ^ x1, h2 ^ x2, h3 ^ x3, h4), - 0) + key(0), key(1), key(2), key(3), key(4), + 0, 0, 0, 0, 0, 0) } - def apply(): ThreefryRandomEngine = { + def apply(nonce: Long, staticID: Long, message: IndexedSeq[Long]): ThreefryRandomEngine = { + val engine = ThreefryRandomEngine() + val (hash, finalTweak) = Threefry.pmacHash(nonce, staticID, message) + engine.resetState(hash(0), hash(1), hash(2), hash(3), finalTweak) + engine + } + + def randState(): ThreefryRandomEngine = { val rand = new java.util.Random() + val key = Threefry.expandKey(Array.fill(4)(rand.nextLong())) new ThreefryRandomEngine( - Threefry.expandKey(Array.fill(4)(rand.nextLong())), - Array.fill(4)(rand.nextLong()), - 0) + key(0), key(1), key(2), key(3), key(4), + rand.nextLong(), rand.nextLong(), rand.nextLong(), rand.nextLong(), + 0, 0) } } class ThreefryRandomEngine( - val key: IndexedSeq[Long], - val state: Array[Long], + val k0: Long, + val k1: Long, + val k2: Long, + val k3: Long, + val k4: Long, + var state0: Long, + var state1: Long, + var state2: Long, + var state3: Long, var counter: Long, - val tweak: Long = SRNGState.finalBlockNoPadTweak -) extends RandomEngine { + var tweak: Long +) extends RandomEngine with RandomGenerator { val buffer: Array[Long] = Array.ofDim[Long](4) var usedInts: Int = 8 var hasBufferedGaussian: Boolean = false @@ -266,98 +314,53 @@ class ThreefryRandomEngine( override def clone(): ThreefryRandomEngine = ??? + def resetState(s0: Long, s1: Long, s2: Long, s3: Long, _tweak: Long): Unit = { + state0 = s0 + state1 = s1 + state2 = s2 + state3 = s3 + tweak = _tweak + counter = 0 + usedInts = 8 + hasBufferedGaussian = false + } + + private[this] val poisState = Poisson.create_random_state() + + def runif(min: Double, max: Double): Double = min + (max - min) * nextDouble() + + def rnorm(mean: Double, sd: Double): Double = mean + sd * nextGaussian() + + def rpois(lambda: Double): Double = Poisson.random(lambda, this, poisState) + + def rbeta(a: Double, b: Double): Double = Beta.random(a, b, this) + + def rgamma(shape: Double, scale: Double): Double = Gamma.random(shape, scale, this) + + def rhyper(numSuccessStates: Double, numFailureStates: Double, numToDraw: Double): Double = + HyperGeometric.random(numSuccessStates, numFailureStates, numToDraw, this) + private def fillBuffer(): Unit = { - import java.lang.Long.rotateLeft - var x0 = state(0) - var x1 = state(1) - var x2 = state(2) - var x3 = state(3) ^ counter - val k0 = key(0); val k1 = key(1); val k2 = key(2); val k3 = key(3) - val k4 = k0 ^ k1 ^ k2 ^ k3 ^ Threefry.keyConst - val t = tweak - // d = 0 - // injectKey s = 0 - x0 += k0; x1 += k1 + t; x2 += k2; x3 += k3 - x0 += x1; x1 = rotateLeft(x1, 14); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 16); x3 ^= x2 - // d = 1 - x0 += x3; x3 = rotateLeft(x3, 52); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 57); x1 ^= x2 - // d = 2 - x0 += x1; x1 = rotateLeft(x1, 23); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 40); x3 ^= x2 - // d = 3 - x0 += x3; x3 = rotateLeft(x3, 5); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 37); x1 ^= x2 - // d = 4 - // injectKey s = 1 - x0 += k1; x1 += k2; x2 += k3 + t; x3 += k4 + 1 - x0 += x1; x1 = rotateLeft(x1, 25); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 33); x3 ^= x2 - // d = 5 - x0 += x3; x3 = rotateLeft(x3, 46); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 12); x1 ^= x2 - // d = 6 - x0 += x1; x1 = rotateLeft(x1, 58); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 22); x3 ^= x2 - // d = 7 - x0 += x3; x3 = rotateLeft(x3, 32); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 32); x1 ^= x2 - // d = 8 - // injectKey s = 2 - x0 += k2; x1 += k3 + t; x2 += k4 + t; x3 += k0 + 2 - x0 += x1; x1 = rotateLeft(x1, 14); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 16); x3 ^= x2 - // d = 9 - x0 += x3; x3 = rotateLeft(x3, 52); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 57); x1 ^= x2 - // d = 10 - x0 += x1; x1 = rotateLeft(x1, 23); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 40); x3 ^= x2 - // d = 11 - x0 += x3; x3 = rotateLeft(x3, 5); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 37); x1 ^= x2 - // d = 12 - // injectKey s = 3 - x0 += k3; x1 += k4 + t; x2 += k0; x3 += k1 + 3 - x0 += x1; x1 = rotateLeft(x1, 25); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 33); x3 ^= x2 - // d = 13 - x0 += x3; x3 = rotateLeft(x3, 46); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 12); x1 ^= x2 - // d = 14 - x0 += x1; x1 = rotateLeft(x1, 58); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 22); x3 ^= x2 - // d = 15 - x0 += x3; x3 = rotateLeft(x3, 32); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 32); x1 ^= x2 - // d = 16 - // injectKey s = 4 - x0 += k4; x1 += k0; x2 += k1 + t; x3 += k2 + 4 - x0 += x1; x1 = rotateLeft(x1, 14); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 16); x3 ^= x2 - // d = 17 - x0 += x3; x3 = rotateLeft(x3, 52); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 57); x1 ^= x2 - // d = 18 - x0 += x1; x1 = rotateLeft(x1, 23); x1 ^= x0 - x2 += x3; x3 = rotateLeft(x3, 40); x3 ^= x2 - // d = 19 - x0 += x3; x3 = rotateLeft(x3, 5); x3 ^= x0 - x2 += x1; x1 = rotateLeft(x1, 37); x1 ^= x2 - // d = 20 - // injectKey s = 5 - x0 += k0; x1 += k1 + t; x2 += k2 + t; x3 += k3 + 5 + buffer(0) = state0; buffer(1) = state1; buffer(2) = state2; buffer(3) = state3 + Threefry.encryptUnrolled(k0, k1, k2, k3, k4, tweak, counter, buffer) - buffer(0) = x0; buffer(1) = x1; buffer(2) = x2; buffer(3) = x3 - counter += 1 usedInts = 0 + counter += 1 } + override def setSeed(seed: Int): Unit = ??? + override def setSeed(seed: Long): Unit = ??? + override def setSeed(seed: Array[Int]): Unit = ??? + override def getSeed: Long = ??? + override def nextBytes(x: Array[Byte]): Unit = ??? + + override def nextBoolean(): Boolean = + (nextInt() ^ 1) == 0 + override def nextLong(): Long = { usedInts += usedInts & 1 // round up to multiple of 2 if (usedInts >= 8) fillBuffer() diff --git a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala index 2b6a574b194..da68808028d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala @@ -463,11 +463,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.union(node.children.forall { case c: IR => lookup(c).required }) // always required - case _: I32 | _: I64 | _: F32 | _: F64 | _: Str | True() | False() | _: IsNA | _: Die | _: UUID4 | _: Consume | _: RNGStateLiteral => - // FIXME: once support for new rng is complete, make states required - case RNGSplit(state, dynBitstring) => - requiredness.union(lookup(state).required) - requiredness.union(lookup(dynBitstring).required) + case _: I32 | _: I64 | _: F32 | _: F64 | _: Str | True() | False() | _: IsNA | _: Die | _: UUID4 | _: Consume | _: RNGStateLiteral | _: RNGSplit => case _: CombOpValue | _: AggStateValue => case Trap(child) => // error message field is missing if the child runs without error diff --git a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala index dbfb52e988f..14425eb6587 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala @@ -79,6 +79,7 @@ object Simplify { private[this] def isStrict(x: IR): Boolean = { x match { case _: Apply | + _: ApplySeeded | _: ApplyUnaryPrimOp | _: ApplyBinaryPrimOp | _: ArrayRef | @@ -86,7 +87,6 @@ object Simplify { _: GetField | _: GetTupleElement => true case ApplyComparisonOp(op, _, _) => op.strict - case f: ApplySeeded => f.implementation.isStrict case _ => false } } @@ -98,6 +98,7 @@ object Simplify { private[this] def hasMissingStrictChild(x: IR): Boolean = { x match { case _: Apply | + _: ApplySeeded | _: ApplyUnaryPrimOp | _: ApplyBinaryPrimOp | _: ArrayRef | @@ -105,7 +106,6 @@ object Simplify { _: GetField | _: GetTupleElement => Children(x).exists(_.isInstanceOf[NA]) case ApplyComparisonOp(op, _, _) if op.strict => Children(x).exists(_.isInstanceOf[NA]) - case f: ApplySeeded if f.implementation.isStrict => f.args.exists(_.isInstanceOf[NA]) case _ => false } } diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 538e1c94739..14f0b1caf2c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -197,7 +197,7 @@ object LoweredTableReader { "key", MakeStruct(FastIndexedSeq( "key" -> Ref("key", keyType), - "token" -> invokeSeeded("rand_unif", 1, TFloat64, NA(TRNGState), F64(0.0), F64(1.0)), + "token" -> invokeSeeded("rand_unif", 1, TFloat64, RNGStateLiteral(), F64(0.0), F64(1.0)), "prevkey" -> ApplyScanOp(FastIndexedSeq(), FastIndexedSeq(Ref("key", keyType)), prevkey)))), "x", Let("n", ApplyAggOp(FastIndexedSeq(), FastIndexedSeq(), count), diff --git a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala index 9d87e0c96b3..76783100e34 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala @@ -272,8 +272,8 @@ object TypeCheck { val td = tcoerce[TDict](x.typ) assert(td.keyType == telt.types(0)) assert(td.valueType == TArray(telt.types(1))) - case RNGStateLiteral(key) => - assert(key.length == 4) + case x@RNGStateLiteral() => + assert(x.typ == TRNGState) case RNGSplit(state, dynBitstring) => assert(state.typ == TRNGState) def isValid: Type => Boolean = { diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 0964597aa50..9c7e93cd716 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -183,11 +183,10 @@ object IRFunctionRegistry { } } - def lookupSeeded(name: String, seed: Long, returnType: Type, arguments: Seq[Type]): Option[(Seq[IR], IR) => IR] = { - lookupFunction(name, returnType, Array.empty[Type], arguments) - .filter(_.isInstanceOf[SeededJVMFunction]) - .map { case f: SeededJVMFunction => - (irArguments: Seq[IR], rngState: IR) => ApplySeeded(name, irArguments, rngState, seed, f.returnType.subst()) + def lookupSeeded(name: String, staticUID: Long, returnType: Type, arguments: Seq[Type]): Option[(Seq[IR], IR) => IR] = { + lookupFunction(name, returnType, Array.empty[Type], TRNGState +: arguments) + .map { f => + (irArguments: Seq[IR], rngState: IR) => ApplySeeded(name, irArguments, rngState, staticUID, f.returnType.subst()) } } @@ -204,7 +203,7 @@ object IRFunctionRegistry { } val validMethods = lookupFunction(name, returnType, typeParameters, arguments) - .filter(!_.isInstanceOf[SeededJVMFunction]).map { f => + .map { f => { (irValueParametersTypes: Seq[Type], irArguments: Seq[IR], errorID: Int) => f match { case _: UnseededMissingnessObliviousJVMFunction => @@ -251,12 +250,7 @@ object IRFunctionRegistry { jvmRegistry.foreach { case (name, fns) => fns.foreach { f => - println(s"""${ - if (f.isInstanceOf[SeededJVMFunction]) - "register_seeded_function" - else - "register_function" - }("${ StringEscapeUtils.escapeString(name) }", (${ f.typeParameters.map(dtype).mkString(",") }), (${ f.valueParameterTypes.map(dtype).mkString(",") }), ${ dtype(f.returnType) })""") + println(s"""register_function("${ StringEscapeUtils.escapeString(name) }", (${ f.typeParameters.map(dtype).mkString(",") }), (${ f.valueParameterTypes.map(dtype).mkString(",") }), ${ dtype(f.returnType) })""") } } } @@ -644,60 +638,6 @@ abstract class RegistryFunctions { def registerIR4(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, returnType: Type, typeParameters: Array[Type] = Array.empty)(f: (Seq[Type], IR, IR, IR, IR, Int) => IR): Unit = registerIR(name, Array(mt1, mt2, mt3, mt4), returnType, typeParameters = typeParameters) { case (t, Seq(a1, a2, a3, a4), errorID) => f(t, a1, a2, a3, a4, errorID) } - - def registerSeeded( - name: String, - valueParameterTypes: Array[Type], - returnType: Type, - computeReturnType: (Type, Seq[SType]) => SType - )( - impl: (EmitCodeBuilder, Value[Region], SType, Long, Array[SValue]) => SValue - ) { - IRFunctionRegistry.addJVMFunction( - new SeededMissingnessObliviousJVMFunction(name, valueParameterTypes, returnType, computeReturnType) { - val isDeterministic: Boolean = false - - def applySeeded(cb: EmitCodeBuilder, seed: Long, r: Value[Region], rpt: SType, args: SValue*): SValue = { - assert(unify(Array.empty[Type], args.map(_.st.virtualType), rpt.virtualType)) - impl(cb, r, rpt, seed, args.toArray) - } - - def applySeededI(seed: Long, cb: EmitCodeBuilder, r: Value[Region], rpt: SType, args: EmitCode*): IEmitCode = { - IEmitCode.multiMapEmitCodes(cb, args.toFastIndexedSeq) { - argPCs => applySeeded(cb, seed, r, rpt, argPCs: _*) - } - } - - override val isStrict: Boolean = true - }) - } - - def registerSeeded0(name: String, returnType: Type, pt: SType)(impl: (EmitCodeBuilder, Value[Region], SType, Long) => SValue): Unit = - registerSeeded(name, Array[Type](), returnType, if (pt == null) null else (_: Type, _: Seq[SType]) => pt) { case (cb, r, rt, seed, _) => impl(cb, r, rt, seed) } - - def registerSeeded1(name: String, arg1: Type, returnType: Type, pt: (Type, SType) => SType)(impl: (EmitCodeBuilder, Value[Region], SType, Long, SValue) => SValue): Unit = - registerSeeded(name, Array(arg1), returnType, unwrappedApply(pt)) { - case (cb, r, rt, seed, Array(a1)) => impl(cb, r, rt, seed, a1) - } - - def registerSeeded2(name: String, arg1: Type, arg2: Type, returnType: Type, pt: (Type, SType, SType) => SType) - (impl: (EmitCodeBuilder, Value[Region], SType, Long, SValue, SValue) => SValue): Unit = - registerSeeded(name, Array(arg1, arg2), returnType, unwrappedApply(pt)) { case - (cb, r, rt, seed, Array(a1, a2)) => - impl(cb, r, rt, seed, a1, a2) - } - - def registerSeeded3(name: String, arg1: Type, arg2: Type, arg3: Type, returnType: Type, pt: (Type, SType, SType, SType) => SType) - (impl: (EmitCodeBuilder, Value[Region], SType, Long, SValue, SValue, SValue) => SValue): Unit = - registerSeeded(name, Array(arg1, arg2, arg3), returnType, unwrappedApply(pt)) { - case (cb, r, rt, seed, Array(a1, a2, a3)) => impl(cb, r, rt, seed, a1, a2, a3) - } - - def registerSeeded4(name: String, arg1: Type, arg2: Type, arg3: Type, arg4: Type, returnType: Type, pt: (Type, SType, SType, SType, SType) => SType) - (impl: (EmitCodeBuilder, Value[Region], SType, Long, SValue, SValue, SValue, SValue) => SValue): Unit = - registerSeeded(name, Array(arg1, arg2, arg3, arg4), returnType, unwrappedApply(pt)) { - case (cb, r, rt, seed, Array(a1, a2, a3, a4)) => impl(cb, r, rt, seed, a1, a2, a3, a4) - } } sealed abstract class JVMFunction { @@ -801,49 +741,3 @@ abstract class UnseededMissingnessAwareJVMFunction ( ??? } } - -abstract class SeededJVMFunction ( - override val name: String, - override val valueParameterTypes: Seq[Type], - override val returnType: Type -) extends JVMFunction { - def typeParameters: Seq[Type] = Seq.empty[Type] - - private[this] var seed: Long = _ - - def setSeed(s: Long): Unit = { seed = s } - - def applySeededI(seed: Long, cb: EmitCodeBuilder, region: Value[Region], rpt: SType, args: EmitCode*): IEmitCode - - def apply(region: EmitRegion, rpt: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode*): EmitCode = - fatal("seeded functions must go through IEmitCode path") - - def apply(region: EmitRegion, rpt: SType, args: EmitCode*): EmitCode = - fatal("seeded functions must go through IEmitCode path") - - def isStrict: Boolean = false -} - -abstract class SeededMissingnessObliviousJVMFunction ( - override val name: String, - override val valueParameterTypes: Seq[Type], - override val returnType: Type, - missingnessObliviousreturnSType: (Type, Seq[SType]) => SType -) extends SeededJVMFunction(name, valueParameterTypes, returnType) { - override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType = { - EmitType(computeStrictReturnEmitType(returnType, valueParameterTypes.map(_.st)), valueParameterTypes.forall(_.required)) - } - - def computeStrictReturnEmitType(returnType: Type, valueParameterTypes: Seq[SType]): SType = - MissingnessObliviousJVMFunction.returnSType(missingnessObliviousreturnSType)(returnType, valueParameterTypes) -} - -abstract class SeededMissingnessAwareJVMFunction ( - override val name: String, - override val valueParameterTypes: Seq[Type], - override val returnType: Type, - missingnessAwarereturnSType: (Type, Seq[EmitType]) => EmitType -) extends SeededJVMFunction(name, valueParameterTypes, returnType) { - override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType = - MissingnessAwareJVMFunction.returnSType(missingnessAwarereturnSType)(returnType, valueParameterTypes) -} diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala index ccb7b695d98..c85df19b878 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala @@ -1,13 +1,15 @@ package is.hail.expr.ir.functions import is.hail.asm4s._ +import is.hail.expr.Nat import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} import is.hail.types.physical.stypes._ -import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SRNGStateStaticSizeValue} +import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SNDArrayPointer, SRNGStateStaticSizeValue, SRNGStateValue} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ -import is.hail.types.physical.{PBoolean, PCanonicalArray, PFloat64, PInt32, PType} +import is.hail.types.physical.{PBoolean, PCanonicalArray, PCanonicalNDArray, PFloat64, PInt32, PType} import is.hail.types.virtual._ +import is.hail.utils.FastIndexedSeq import net.sourceforge.jdistlib.rng.MersenneTwister import net.sourceforge.jdistlib.{Beta, Gamma, HyperGeometric, Poisson} @@ -30,6 +32,8 @@ class IRRandomness(seed: Long) { def rint32(n: Int): Int = random.nextInt(n) + def rint64(): Long = random.nextLong() + def rint64(n: Long): Long = random.nextLong(n) def rcoin(p: Double): Boolean = random.nextDouble() < p @@ -42,7 +46,8 @@ class IRRandomness(seed: Long) { def rgamma(shape: Double, scale: Double): Double = Gamma.random(shape, scale, random) - def rhyper(numSuccessStates: Double, numFailureStates: Double, numToDraw: Double): Double = HyperGeometric.random(numSuccessStates, numFailureStates, numToDraw, random) + def rhyper(numSuccessStates: Double, numFailureStates: Double, numToDraw: Double): Double = + HyperGeometric.random(numSuccessStates, numFailureStates, numToDraw, random) def rcat(prob: Array[Double]): Int = { var i = 0 @@ -91,105 +96,150 @@ object RandomSeededFunctions extends RegistryFunctions { } def registerAll() { - registerSeeded2("rand_unif", TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType) => SFloat64 - }) { case (cb, r, rt, seed, min, max) => - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Double, Double, Double]("runif", min.asDouble.value, max.asDouble.value))) - } - - registerSCode3("rand_unif_pure", TRNGState, TFloat64, TFloat64, TFloat64, { + registerSCode3("rand_unif", TRNGState, TFloat64, TFloat64, TFloat64, { case (_: Type, _: SType, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateStaticSizeValue, min: SFloat64Value, max: SFloat64Value, errorID) => + }) { case (_, cb, rt, rngState: SRNGStateValue, min: SFloat64Value, max: SFloat64Value, errorID) => primitive(cb.memoize(rand_unif(cb, rngState.rand(cb)) * (max.value - min.value) + min.value)) } - registerSeeded1("rand_int32", TInt32, TInt32, { - case (_: Type, _: SType) => SInt32 - }) { case (cb, r, rt, seed, n) => - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Int, Int]("rint32", n.asInt.value))) + registerSCode5("rand_unif_nd", TRNGState, TInt64, TInt64, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), { + case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => PCanonicalNDArray(PFloat64(true), 2, true).sType + }) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt64Value, nCols: SInt64Value, min, max, errorID) => + val result = rt.pType.constructUninitialized(FastIndexedSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), cb, r.region) + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + result.coiterateMutate(cb, r.region) { _ => + primitive(cb.memoize(rng.invoke[Double, Double, Double]("runif", min.asDouble.value, max.asDouble.value))) + } + result } - registerSeeded1("rand_int64", TInt64, TInt64, { - case (_: Type, _: SType) => SInt64 - }) { case (cb, r, rt, seed, n) => - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Long, Long]("rint64", n.asLong.value))) + registerSCode2("rand_int32", TRNGState, TInt32, TInt32, { + case (_: Type, _: SType, _: SType) => SInt32 + }) { case (r, cb, rt, rngState: SRNGStateValue, n: SInt32Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Int, Int]("nextInt", n.value))) } - registerSeeded2("rand_norm", TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType) => SFloat64 - }) { case (cb, r, rt, seed, mean, sd) => - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Double, Double, Double]("rnorm", mean.asDouble.value, sd.asDouble.value))) + registerSCode2("rand_int64", TRNGState, TInt64, TInt64, { + case (_: Type, _: SType, _: SType) => SInt64 + }) { case (r, cb, rt, rngState: SRNGStateValue, n: SInt64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Long, Long]("nextLong", n.value))) } - registerSeeded1("rand_bool", TFloat64, TBoolean, (_: Type, _: SType) => SBoolean) { case (cb, r, rt, seed, p) => - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Double, Boolean]("rcoin", p.asDouble.value))) + registerSCode1("rand_int64", TRNGState, TInt64, { + case (_: Type, _: SType) => SInt64 + }) { case (r, cb, rt, rngState: SRNGStateValue, errorID) => + primitive(rngState.rand(cb)(0)) } - registerSeeded1("rand_pois", TFloat64, TFloat64, (_: Type, _: SType) => SFloat64) { case (cb, r, rt, seed, lambda) => - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Double, Double]("rpois", lambda.asDouble.value))) + registerSCode5("rand_norm_nd", TRNGState, TInt64, TInt64, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), { + case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => PCanonicalNDArray(PFloat64(true), 2, true).sType + }) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt64Value, nCols: SInt64Value, mean, sd, errorID) => + val result = rt.pType.constructUninitialized(FastIndexedSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), cb, r.region) + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + result.coiterateMutate(cb, r.region) { _ => + primitive(cb.memoize(rng.invoke[Double, Double, Double]("rnorm", mean.asDouble.value, sd.asDouble.value))) + } + result } - registerSeeded2("rand_pois", TInt32, TFloat64, TArray(TFloat64), { - case (_: Type, _: SType, _: SType) => PCanonicalArray(PFloat64(true)).sType - }) { case (cb, r, SIndexablePointer(rt: PCanonicalArray), seed, n, lambdaCode) => - val len = n.asInt.value - val lambda = lambdaCode.asDouble.value + registerSCode3("rand_norm", TRNGState, TFloat64, TFloat64, TFloat64, { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }) { case (_, cb, rt, rngState: SRNGStateValue, mean: SFloat64Value, sd: SFloat64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Double, Double, Double]("rnorm", mean.value, sd.value))) + } - rt.constructFromElements(cb, r, len, deepCopy = false) { case (cb, _) => - IEmitCode.present(cb, primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Double, Double]("rpois", lambda)))) - } + registerSCode2("rand_bool", TRNGState, TFloat64, TBoolean, { + case (_: Type, _: SType, _: SType) => SBoolean + }) { case (_, cb, rt, rngState: SRNGStateValue, p: SFloat64Value, errorID) => + val u = rand_unif(cb, rngState.rand(cb)) + primitive(cb.memoize(u < p.value)) } - registerSeeded2("rand_beta", TFloat64, TFloat64, TFloat64, { + registerSCode2("rand_pois", TRNGState, TFloat64, TFloat64, { case (_: Type, _: SType, _: SType) => SFloat64 - }) { case (cb, r, rt, seed, a, b) => - primitive(cb.memoize( - cb.emb.newRNG(seed).invoke[Double, Double, Double]("rbeta", - a.asDouble.value, b.asDouble.value))) - } - - registerSeeded4("rand_beta", TFloat64, TFloat64, TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType, _: SType, _: SType) => SFloat64 - }) { - case (cb, r, rt, seed, a, b, min, max) => - val rng = cb.emb.newRNG(seed) - val la = a.asDouble.value - val lb = b.asDouble.value - val lmin = min.asDouble.value - val lmax = max.asDouble.value - val value = cb.newLocal[Double]("value", rng.invoke[Double, Double, Double]("rbeta", la, lb)) - cb.whileLoop(value < lmin || value > lmax, { - cb.assign(value, rng.invoke[Double, Double, Double]("rbeta", la, lb)) - }) - primitive(value) + }) { case (_, cb, rt, rngState: SRNGStateValue, lambda: SFloat64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Double, Double]("rpois", lambda.value))) } - registerSeeded2("rand_gamma", TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType) => SFloat64 - }) { case (cb, r, rt, seed, a, scale) => - primitive(cb.memoize( - cb.emb.newRNG(seed).invoke[Double, Double, Double]("rgamma", a.asDouble.value, scale.asDouble.value) - )) + registerSCode3("rand_pois", TRNGState, TInt32, TFloat64, TArray(TFloat64), { + case (_: Type, _: SType, _: SType, _: SType) => PCanonicalArray(PFloat64(true)).sType + }) { case (r, cb, SIndexablePointer(rt: PCanonicalArray), rngState: SRNGStateValue, n: SInt32Value, lambda: SFloat64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + rt.constructFromElements(cb, r.region, n.value, deepCopy = false) { case (cb, _) => + IEmitCode.present(cb, + primitive(cb.memoize(rng.invoke[Double, Double]("rpois", lambda.value))) + ) + } } - registerSeeded1("rand_cat", TArray(TFloat64), TInt32, (_: Type, _: SType) => SInt32) { case (cb, r, rt, seed, weights: SIndexableValue) => - val len = weights.loadLength() + registerSCode3("rand_beta", TRNGState, TFloat64, TFloat64, TFloat64, { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }) { case (_, cb, rt, rngState: SRNGStateValue, a: SFloat64Value, b: SFloat64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Double, Double, Double]("rbeta", a.value, b.value))) + } + + registerSCode5("rand_beta", TRNGState, TFloat64, TFloat64, TFloat64, TFloat64, TFloat64, { + case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => SFloat64 + }) { case (_, cb, rt, rngState: SRNGStateValue, a: SFloat64Value, b: SFloat64Value, min: SFloat64Value, max: SFloat64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + val value = cb.newLocal[Double]("value", rng.invoke[Double, Double, Double]("rbeta", a.value, b.value)) + cb.whileLoop(value < min.value || value > max.value, { + cb.assign(value, rng.invoke[Double, Double, Double]("rbeta", a.value, b.value)) + }) + primitive(value) + } - val a = cb.newLocal[Array[Double]]("rand_cat_a", Code.newArray[Double](len)) + registerSCode3("rand_gamma", TRNGState, TFloat64, TFloat64, TFloat64, { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }) { case (_, cb, rt, rngState: SRNGStateValue, a: SFloat64Value, scale: SFloat64Value, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Double, Double, Double]("rgamma", a.value, scale.value))) + } - val i = cb.newLocal[Int]("rand_cat_i", 0) + registerSCode2("rand_cat", TRNGState, TArray(TFloat64), TInt32, { + case (_: Type, _: SType, _: SType) => SInt32 + }) { case (_, cb, rt, rngState: SRNGStateValue, weights: SIndexableValue, errorID) => + val len = weights.loadLength() + val i = cb.newLocal[Int]("i", 0) + val s = cb.newLocal[Double]("sum", 0.0) cb.whileLoop(i < len, { - weights.loadElement(cb, i).consume(cb, - cb._fatal("rand_cat requires all elements of input array to be present"), - sc => cb += a.update(i, sc.asDouble.value) - ) + cb.assign(s, s + weights.loadElement(cb, i).get(cb, "rand_cat requires all elements of input array to be present").asFloat64.value) cb.assign(i, i + 1) }) - primitive(cb.memoize(cb.emb.newRNG(seed).invoke[Array[Double], Int]("rcat", a))) + val r = cb.newLocal[Double]("r", rand_unif(cb, rngState.rand(cb)) * s) + cb.assign(i, 0) + val elt = cb.newLocal[Double]("elt") + cb.loop { start => + cb.assign(elt, weights.loadElement(cb, i).get(cb, "rand_cat requires all elements of input array to be present").asFloat64.value) + cb.ifx(r > elt && i < len, { + cb.assign(r, r - elt) + cb.assign(i, i + 1) + cb.goto(start) + }) + } + primitive(i) } - registerSeeded2("shuffle_compute_num_samples_per_partition", TInt32, TArray(TInt32), TArray(TInt32), - (_, _, _) => SIndexablePointer(PCanonicalArray(PInt32(true), false))) { case (cb, r, rt, seed, initalNumSamplesToSelect: SInt32Value, partitionCounts: SIndexableValue) => + registerSCode3("shuffle_compute_num_samples_per_partition", TRNGState, TInt32, TArray(TInt32), TArray(TInt32), + (_, _, _, _) => SIndexablePointer(PCanonicalArray(PInt32(true), false)) + ) { case (r, cb, rt, rngState: SRNGStateValue, initalNumSamplesToSelect: SInt32Value, partitionCounts: SIndexableValue, errorID) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) val totalNumberOfRecords = cb.newLocal[Int]("scnspp_total_number_of_records", 0) val resultSize: Value[Int] = partitionCounts.loadLength() @@ -205,10 +255,10 @@ object RandomSeededFunctions extends RegistryFunctions { val failureStatesRemaining = cb.newLocal[Int]("scnspp_failure", totalNumberOfRecords - successStatesRemaining) val arrayRt = rt.asInstanceOf[SIndexablePointer] - val (push, finish) = arrayRt.pType.asInstanceOf[PCanonicalArray].constructFromFunctions(cb, r, resultSize, false) + val (push, finish) = arrayRt.pType.asInstanceOf[PCanonicalArray].constructFromFunctions(cb, r.region, resultSize, false) cb.forLoop(cb.assign(i, 0), i < resultSize, cb.assign(i, i + 1), { - val numSuccesses = cb.memoize(cb.emb.newRNG(seed).invoke[Double, Double, Double, Double]("rhyper", + val numSuccesses = cb.memoize(rng.invoke[Double, Double, Double, Double]("rhyper", successStatesRemaining.toD, failureStatesRemaining.toD, partitionCounts.loadElement(cb, i).get(cb).asInt32.value.toD).toI) cb.assign(successStatesRemaining, successStatesRemaining - numSuccesses) cb.assign(failureStatesRemaining, failureStatesRemaining - (partitionCounts.loadElement(cb, i).get(cb).asInt32.value - numSuccesses)) diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala index dd5db25b27e..3fe2f0c0f58 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala @@ -271,16 +271,19 @@ object LowerBlockMatrixIR { bmir match { case BlockMatrixRead(reader) => reader.lower(ctx) - case x@BlockMatrixRandom(seed, gaussian, shape, blockSize) => - val generator = invokeSeeded(if (gaussian) "rand_norm" else "rand_unif", seed, TFloat64, NA(TRNGState), F64(0.0), F64(1.0)) - new BlockMatrixStage(IndexedSeq(), Array(), TTuple(TInt64, TInt64)) { + case x@BlockMatrixRandom(staticUID, gaussian, shape, blockSize) => + new BlockMatrixStage(IndexedSeq(), Array(), TTuple(TInt64, TInt64, TInt32)) { def blockContext(idx: (Int, Int)): IR = { - val (i, j) = x.typ.blockShape(idx._1, idx._2) - MakeTuple.ordered(FastSeq(i, j)) + val (m, n) = x.typ.blockShape(idx._1, idx._2) + MakeTuple.ordered(FastSeq(m, n, idx._1 * x.typ.nColBlocks + idx._2)) } def blockBody(ctxRef: Ref): IR = { - val len = (GetTupleElement(ctxRef, 0) * GetTupleElement(ctxRef, 1)).toI - MakeNDArray(ToArray(mapIR(rangeIR(len))(_ => generator)), ctxRef, True(), ErrorIDs.NO_ERROR) + val m = GetTupleElement(ctxRef, 0) + val n = GetTupleElement(ctxRef, 1) + val i = GetTupleElement(ctxRef, 2) + val f = if (gaussian) "rand_norm_nd" else "rand_unif_nd" + val rngState = RNGSplit(RNGStateLiteral(), Cast(i, TInt64)) + invokeSeeded(f, staticUID, TNDArray(TFloat64, Nat(2)), rngState, m, n, F64(0.0), F64(1.0)) } } case BlockMatrixMap(child, eltName, f, _) => diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 9f9757c082a..0c54b43fb5a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -528,7 +528,7 @@ object LowerTableIR { MakeArray( ApplyAggOp( FastIndexedSeq(I32(samplesPerPartition)), - FastIndexedSeq(SelectFields(elt, keyType.fieldNames), invokeSeeded("rand_unif", 1, TFloat64, NA(TRNGState), F64(0.0), F64(1.0))), + FastIndexedSeq(SelectFields(elt, keyType.fieldNames), invokeSeeded("rand_unif", 1, TFloat64, RNGStateLiteral(), F64(0.0), F64(1.0))), samplekey), ApplyAggOp( FastIndexedSeq(I32(1)), diff --git a/hail/src/main/scala/is/hail/expr/ir/package.scala b/hail/src/main/scala/is/hail/expr/ir/package.scala index ca50905e9c7..3739ddc634f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/package.scala +++ b/hail/src/main/scala/is/hail/expr/ir/package.scala @@ -52,10 +52,11 @@ package object ir { def invoke(name: String, rt: Type, errorID: Int, args: IR*): IR = invoke(name, rt, Array.empty[Type], errorID, args:_*) - def invokeSeeded(name: String, seed: Long, rt: Type, rngState: IR, args: IR*): IR = IRFunctionRegistry.lookupSeeded(name, seed, rt, args.map(_.typ)) match { - case Some(f) => f(args, rngState) - case None => fatal(s"no seeded function found for $name(${args.map(_.typ).mkString(", ")}) => $rt") - } + def invokeSeeded(name: String, staticUID: Long, rt: Type, rngState: IR, args: IR*): IR = + IRFunctionRegistry.lookupSeeded(name, staticUID, rt, args.map(_.typ)) match { + case Some(f) => f(args, rngState) + case None => fatal(s"no seeded function found for $name(${args.map(_.typ).mkString(", ")}) => $rt") + } implicit def irToPrimitiveIR(ir: IR): PrimitiveIR = new PrimitiveIR(ir) diff --git a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala index 94352a639e0..2acb5fa3de6 100644 --- a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala +++ b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala @@ -1,37 +1,33 @@ package is.hail.linalg -import java.io._ -import java.nio._ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum => breezeSum, _} import breeze.numerics.{abs => breezeAbs, log => breezeLog, pow => breezePow, sqrt => breezeSqrt} import breeze.stats.distributions.{RandBasis, ThreadLocalRandomGenerator} import is.hail._ import is.hail.annotations._ -import is.hail.backend.{BroadcastValue, ExecuteContext, HailTaskContext} import is.hail.backend.spark.{SparkBackend, SparkTaskContext} -import is.hail.utils._ -import is.hail.expr.Parser -import is.hail.expr.ir.{CompileAndEvaluate, IR, IntArrayBuilder, TableReader, TableValue} -import is.hail.types._ -import is.hail.types.physical.{PArray, PCanonicalArray, PCanonicalStruct, PFloat64, PFloat64Optional, PFloat64Required, PInt64, PInt64Optional, PInt64Required, PStruct} -import is.hail.types.virtual._ +import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.expr.ir.{IntArrayBuilder, TableReader, TableValue, ThreefryRandomEngine} import is.hail.io._ -import is.hail.rvd.{RVD, RVDContext, RVDPartitioner} +import is.hail.io.fs.FS +import is.hail.io.index.IndexWriter +import is.hail.rvd.{RVD, RVDContext} import is.hail.sparkextras.{ContextRDD, OriginUnionPartition, OriginUnionRDD} +import is.hail.types._ +import is.hail.types.physical._ +import is.hail.types.virtual._ import is.hail.utils._ import is.hail.utils.richUtils.{ByteTrackingOutputStream, RichArray, RichContextRDD, RichDenseMatrixDouble} -import is.hail.io.fs.FS -import is.hail.io.index.IndexWriter import org.apache.commons.lang3.StringUtils import org.apache.commons.math3.random.MersenneTwister -import org.apache.spark.executor.InputMetrics import org.apache.spark._ +import org.apache.spark.executor.InputMetrics import org.apache.spark.mllib.linalg.distributed._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.storage.StorageLevel import org.json4s._ +import java.io._ import scala.collection.immutable.NumericRange case class CollectMatricesRDDPartition(index: Int, firstPartition: Int, blockPartitions: Array[Partition], blockSize: Int, nRows: Int, nCols: Int) extends Partition { @@ -138,12 +134,11 @@ object BlockMatrix { // uniform or Gaussian def random(nRows: Long, nCols: Long, blockSize: Int = defaultBlockSize, - seed: Long = 0, gaussian: Boolean): M = + nonce: Long = 0, staticUID: Long = 0, gaussian: Boolean): M = BlockMatrix(GridPartitioner(blockSize, nRows, nCols), (gp, pi) => { val (i, j) = gp.blockCoordinates(pi) - val blockSeed = seed + 15485863 * pi // millionth prime - - val randBasis: RandBasis = new RandBasis(new ThreadLocalRandomGenerator(new MersenneTwister(blockSeed))) + val generator = ThreefryRandomEngine(nonce, staticUID, Array(pi.toLong)) + val randBasis: RandBasis = new RandBasis(generator) val rand = if (gaussian) randBasis.gaussian else randBasis.uniform ((i, j), BDM.rand[Double](gp.blockRowNRows(i), gp.blockColNCols(j), rand)) diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala index 354d912a9d4..bb5107df355 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala @@ -174,6 +174,14 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo constructByCopyingDataPointer(shape, strides, this.allocateData(shape, region), cb, region) } + def constructUninitialized( + shape: IndexedSeq[SizeValue], + cb: EmitCodeBuilder, + region: Value[Region] + ): SNDArrayPointerValue = { + constructByCopyingDataPointer(shape, makeColumnMajorStrides(shape, region, cb), this.allocateData(shape, region), cb, region) + } + def constructByCopyingArray( shape: IndexedSeq[Value[Long]], strides: IndexedSeq[Value[Long]], diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala index 297fa896745..efaac4b2225 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala @@ -2,7 +2,7 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.expr.ir.{EmitCodeBuilder, Threefry} +import is.hail.expr.ir.{EmitCodeBuilder, Threefry, ThreefryRandomEngine} import is.hail.types.{RPrimitive, TypeWithRequiredness} import is.hail.types.physical.{PType, StoredSTypePType} import is.hail.types.physical.stypes.primitives.SInt64Value @@ -13,12 +13,6 @@ import is.hail.utils.{Bitstring, FastIndexedSeq, toRichIterable} import scala.collection.mutable import scala.collection.mutable -object SRNGState { - val staticTweak = -1L - val finalBlockNoPadTweak = -2L - val finalBlockPaddedTweak = -3L -} - final case class SRNGStateStaticInfo(numWordsInLastBlock: Int, hasStaticSplit: Boolean, numDynBlocks: Int) { assert(numWordsInLastBlock <= 4 && numWordsInLastBlock >= 0) } @@ -87,10 +81,22 @@ trait SRNGStateValue extends SValue { def splitStatic(cb: EmitCodeBuilder, idx: Long): SRNGStateValue def splitDyn(cb: EmitCodeBuilder, idx: Value[Long]): SRNGStateValue def rand(cb: EmitCodeBuilder): IndexedSeq[Value[Long]] + def copyIntoEngine(cb: EmitCodeBuilder, tf: Value[ThreefryRandomEngine]): Unit } trait SRNGStateSettable extends SRNGStateValue with SSettable +object SCanonicalRNGStateValue { + def apply(cb: EmitCodeBuilder): SCanonicalRNGStateValue = { + val typ = SRNGState(None) + new SCanonicalRNGStateValue( + typ, + Array.fill[Value[Long]](4)(0), + Array.fill[Value[Long]](4)(0), + 0, false, 0) + } +} + class SCanonicalRNGStateValue( override val st: SRNGState, val runningSum: IndexedSeq[Value[Long]], @@ -109,11 +115,12 @@ class SCanonicalRNGStateValue( new SInt64Value(4*8 + 4*8 + 4 + 4 + 4) def splitStatic(cb: EmitCodeBuilder, idx: Long): SCanonicalRNGStateValue = { - cb.ifx(!hasStaticSplit, cb._fatal("RNGState received two static splits")) + cb.ifx(hasStaticSplit, cb._fatal("RNGState received two static splits")) val x = Array.ofDim[Long](4) - x(0) = idx - val key = cb.emb.ctx.rngKey - Threefry.encrypt(key, SRNGState.staticTweak, x) + x(0) = cb.emb.ctx.rngNonce + x(1) = idx + val key = Threefry.defaultKey + Threefry.encrypt(key, Array(Threefry.staticTweak, 0L), x) val newDynBlocksSum = Array.tabulate[Value[Long]](4)(i => cb.memoize(runningSum(i) ^ x(i))) new SCanonicalRNGStateValue(st, newDynBlocksSum, lastDynBlock, numWordsInLastBlock, const(true), numDynBlocks) @@ -136,8 +143,8 @@ class SCanonicalRNGStateValue( newLastDynBlock(3) := idx)) cb.assign(newNumWordsInLastBlock, newNumWordsInLastBlock + 1) }, { - val key = cb.emb.ctx.rngKey - Threefry.encrypt(cb, key, cb.memoize(numDynBlocks.toL), newLastDynBlock) + val key = Threefry.defaultKey + Threefry.encrypt(cb, key, Array(cb.memoize(numDynBlocks.toL), const(0L)), newLastDynBlock) for (i <- 0 until 4) cb.assign(newRunningSum(i), newRunningSum(i) ^ newLastDynBlock(i)) cb.assign(newLastDynBlock(0), idx) for (i <- 1 until 4) cb.assign(newLastDynBlock(i), 0L) @@ -149,20 +156,40 @@ class SCanonicalRNGStateValue( } def rand(cb: EmitCodeBuilder): IndexedSeq[Value[Long]] = { + cb.ifx(!hasStaticSplit, cb._fatal("RNGState never received static split")) val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"rand_x$i", runningSum(i))) - val key = cb.emb.ctx.rngKey - val tweak = cb.ifx(numWordsInLastBlock.ceq(4), SRNGState.finalBlockNoPadTweak, SRNGState.finalBlockPaddedTweak) + val key = Threefry.defaultKey + val finalTweak = cb.ifx(numWordsInLastBlock.ceq(4), Threefry.finalBlockNoPadTweak, Threefry.finalBlockPaddedTweak) + for (i <- 0 until 4) cb.assign(x(i), x(i) ^ lastDynBlock(i)) cb += Code.switch( numWordsInLastBlock, Code._fatal[Unit]("invalid numWordsInLastBlock"), FastIndexedSeq( - x(0) := x(0) ^ (1L << 63), - x(1) := x(1) ^ (1L << 63), - x(2) := x(2) ^ (1L << 63), - x(3) := x(3) ^ (1L << 63))) - Threefry.encrypt(cb, key, tweak, x) + x(0) := x(0) ^ 1L, + x(1) := x(1) ^ 1L, + x(2) := x(2) ^ 1L, + x(3) := x(3) ^ 1L, + Code._empty)) + Threefry.encrypt(cb, key, Array(finalTweak, const(0L)), x) x } + + def copyIntoEngine(cb: EmitCodeBuilder, tf: Value[ThreefryRandomEngine]): Unit = { + cb.ifx(!hasStaticSplit, cb._fatal("RNGState never received static split")) + val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"cie_x$i", runningSum(i))) + val finalTweak = cb.ifx(numWordsInLastBlock.ceq(4), Threefry.finalBlockNoPadTweak, Threefry.finalBlockPaddedTweak) + for (i <- 0 until 4) cb.assign(x(i), x(i) ^ lastDynBlock(i)) + cb += Code.switch( + numWordsInLastBlock, + Code._fatal[Unit]("invalid numWordsInLastBlock"), + FastIndexedSeq( + x(0) := x(0) ^ 1L, + x(1) := x(1) ^ 1L, + x(2) := x(2) ^ 1L, + x(3) := x(3) ^ 1L, + Code._empty)) + cb += tf.invoke[Long, Long, Long, Long, Long, Unit]("resetState", x(0), x(1), x(2), x(3), finalTweak) + } } class SCanonicalRNGStateSettable( @@ -190,7 +217,7 @@ class SCanonicalRNGStateSettable( } object SRNGStateStaticSizeValue { - def apply(cb: EmitCodeBuilder, key: IndexedSeq[Long]): SRNGStateStaticSizeValue = { + def apply(cb: EmitCodeBuilder): SRNGStateStaticSizeValue = { val typ = SRNGState(Some(SRNGStateStaticInfo(0, false, 0))) new SRNGStateStaticSizeValue( typ, @@ -217,9 +244,10 @@ class SRNGStateStaticSizeValue( def splitStatic(cb: EmitCodeBuilder, idx: Long): SRNGStateStaticSizeValue = { assert(!staticInfo.hasStaticSplit) val x = Array.ofDim[Long](4) - x(0) = idx - val key = cb.emb.ctx.rngKey - Threefry.encrypt(key, SRNGState.staticTweak, x) + x(0) = cb.emb.ctx.rngNonce + x(1) = idx + val key = Threefry.defaultKey + Threefry.encrypt(key, Array(Threefry.staticTweak, 0L), x) val newDynBlocksSum = Array.tabulate[Value[Long]](4)(i => cb.memoize(runningSum(i) ^ x(i))) new SRNGStateStaticSizeValue( @@ -237,8 +265,8 @@ class SRNGStateStaticSizeValue( ) } val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_x$i", lastDynBlock(i))) - val key = cb.emb.ctx.rngKey - Threefry.encrypt(cb, key, staticInfo.numDynBlocks.toLong, x) + val key = Threefry.defaultKey + Threefry.encrypt(cb, key, Array(const(staticInfo.numDynBlocks.toLong), const(0L)), x) for (i <- 0 until 4) cb.assign(x(i), x(i) ^ runningSum(i)) new SRNGStateStaticSizeValue( @@ -249,18 +277,31 @@ class SRNGStateStaticSizeValue( } def rand(cb: EmitCodeBuilder): IndexedSeq[Value[Long]] = { + assert(staticInfo.hasStaticSplit) val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"rand_x$i", runningSum(i))) - val key = cb.emb.ctx.rngKey + val key = Threefry.defaultKey if (staticInfo.numWordsInLastBlock == 4) { for (i <- lastDynBlock.indices) cb.assign(x(i), x(i) ^ lastDynBlock(i)) - Threefry.encrypt(cb, key, SRNGState.finalBlockNoPadTweak, x) + Threefry.encrypt(cb, key, Array(const(Threefry.finalBlockNoPadTweak), const(0L)), x) } else { for (i <- lastDynBlock.indices) cb.assign(x(i), x(i) ^ lastDynBlock(i)) - cb.assign(x(lastDynBlock.size), x(lastDynBlock.size) ^ (1L << 63)) - Threefry.encrypt(cb, key, SRNGState.finalBlockPaddedTweak, x) + cb.assign(x(lastDynBlock.size), x(lastDynBlock.size) ^ 1L) + Threefry.encrypt(cb, key, Array(const(Threefry.finalBlockPaddedTweak), const(0L)), x) } x } + + def copyIntoEngine(cb: EmitCodeBuilder, tf: Value[ThreefryRandomEngine]): Unit = { + val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"cie_x$i", runningSum(i))) + for (i <- lastDynBlock.indices) cb.assign(x(i), x(i) ^ lastDynBlock(i)) + val finalTweak = if (staticInfo.numWordsInLastBlock < 4) { + cb.assign(x(lastDynBlock.size), x(lastDynBlock.size) ^ 1L) + Threefry.finalBlockPaddedTweak + } else { + Threefry.finalBlockNoPadTweak + } + cb += tf.invoke[Long, Long, Long, Long, Long, Unit]("resetState", x(0), x(1), x(2), x(3), finalTweak) + } } class SRNGStateStaticSizeSettable( diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala index c40697e7c16..1462a66a723 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala @@ -652,9 +652,8 @@ trait SNDArrayValue extends SValue { coiterateMutate(cb, region, false, arrays: _*)(body) def coiterateMutate(cb: EmitCodeBuilder, region: Value[Region], deepCopy: Boolean, arrays: (SNDArrayValue, String)*)(body: IndexedSeq[SValue] => SValue): Unit = { - if (arrays.isEmpty) return - val indexVars = Array.tabulate(arrays(0)._1.st.nDims)(i => s"i$i").toFastIndexedSeq - val indices = Array.range(0, arrays(0)._1.st.nDims).toFastIndexedSeq + val indexVars = Array.tabulate(st.nDims)(i => s"i$i").toFastIndexedSeq + val indices = Array.range(0, st.nDims).toFastIndexedSeq coiterateMutate(cb, region, deepCopy, indexVars, indices, arrays.map { case (array, name) => (array, indices, name) }: _*)(body) } diff --git a/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala index 69b42c0439a..29321325e55 100644 --- a/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala @@ -8,7 +8,7 @@ import org.testng.annotations.{DataProvider, Test} class FoldConstantsSuite extends HailSuite { @Test def testRandomBlocksFolding() { - val x = ApplySeeded("rand_norm", Seq(F64(0d), F64(0d)), NA(TRNGState), 0L, TFloat64) + val x = ApplySeeded("rand_norm", Seq(F64(0d), F64(0d)), RNGStateLiteral(), 0L, TFloat64) assert(FoldConstants(ctx, x) == x) } diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index 28a19af11a6..9de8924e101 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -2,13 +2,11 @@ package is.hail.expr.ir import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ -import is.hail.annotations.{BroadcastRow, ExtendedOrdering, Region, SafeNDArray} -import is.hail.asm4s.{Code, Value} +import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.ArrayZipBehavior.ArrayZipBehavior import is.hail.expr.ir.IRBuilder._ -import is.hail.expr.ir.IRSuite.TestFunctions import is.hail.expr.ir.agg._ import is.hail.expr.ir.functions._ import is.hail.io.bgen.{IndexBgen, MatrixBGENReader} @@ -30,67 +28,30 @@ import org.testng.annotations.{DataProvider, Test} import scala.language.{dynamics, implicitConversions} -object IRSuite { - outer => - var globalCounter: Int = 0 - - def incr(): Unit = { - globalCounter += 1 - } - - object TestFunctions extends RegistryFunctions { - - def registerSeededWithMissingness( - name: String, - valueParameterTypes: Array[Type], - returnType: Type, - calculateReturnType: (Type, Seq[EmitType]) => EmitType - )( - impl: (EmitCodeBuilder, Value[Region], SType, Long, Array[EmitCode]) => IEmitCode - ) { - IRFunctionRegistry.addJVMFunction( - new SeededMissingnessAwareJVMFunction(name, valueParameterTypes, returnType, calculateReturnType) { - val isDeterministic: Boolean = false - def applySeededI(seed: Long, cb: EmitCodeBuilder, r: Value[Region], returnPType: SType, args: EmitCode*): IEmitCode = { - assert(unify(FastSeq(), args.map(_.st.virtualType), returnPType.virtualType)) - impl(cb, r, returnPType, seed, args.toArray) - } - } - ) - } +class IRSuite extends HailSuite { + implicit val execStrats = ExecStrategy.nonLowering - def registerSeededWithMissingness1( - name: String, - valueParameterType: Type, - returnType: Type, - calculateReturnType: (Type, EmitType) => EmitType - )( - impl: (EmitCodeBuilder, Value[Region], SType, Long, EmitCode) => IEmitCode - ): Unit = - registerSeededWithMissingness(name, Array(valueParameterType), returnType, unwrappedApply(calculateReturnType)) { - case (cb, r, rt, seed, Array(a1)) => impl(cb, r, rt, seed, a1) - } + @Test def testRandDifferentLengthUIDStrings() { + implicit val execStrats = ExecStrategy.lowering + val staticUID: Long = 112233 + var rng: IR = RNGStateLiteral() + rng = RNGSplit(rng, I64(12345)) + val expected1 = Threefry.pmac(ctx.rngNonce, staticUID, Array(12345L)) + assertEvalsTo(ApplySeeded("rand_int64", Seq(), rng, staticUID, TInt64), expected1(0)) - def registerAll() { - registerSeededWithMissingness1("incr_s", TBoolean, TBoolean, { (ret: Type, pt: EmitType) => pt }) { case (cb, r, _, _, l) => - cb += Code.invokeScalaObject0[Unit](outer.getClass, "incr") - l.toI(cb) - } + rng = RNGSplit(rng, I64(0)) + val expected2 = Threefry.pmac(ctx.rngNonce, staticUID, Array(12345L, 0L)) + assertEvalsTo(ApplySeeded("rand_int64", Seq(), rng, staticUID, TInt64), expected2(0)) - registerSeededWithMissingness1("incr_v", TBoolean, TBoolean, { (ret: Type, pt: EmitType) => pt }) { case (cb, _, _, _, l) => - l.toI(cb).map(cb) { pc => - cb += Code.invokeScalaObject0[Unit](outer.getClass, "incr") - pc - } - } - } + rng = RNGSplit(rng, I64(0)) + rng = RNGSplit(rng, I64(0)) + val expected3 = Threefry.pmac(ctx.rngNonce, staticUID, Array(12345L, 0L, 0L, 0L)) + assertEvalsTo(ApplySeeded("rand_int64", Seq(), rng, staticUID, TInt64), expected3(0)) + assert(expected1 != expected2) + assert(expected2 != expected3) + assert(expected1 != expected3) } -} - -class IRSuite extends HailSuite { - implicit val execStrats = ExecStrategy.nonLowering - @Test def testI32() { assertEvalsTo(I32(5), 5) } @@ -2722,7 +2683,7 @@ class IRSuite extends HailSuite { val nd = MakeNDArray(MakeArray(FastSeq(I32(-1), I32(1)), TArray(TInt32)), MakeTuple.ordered(FastSeq(I64(1), I64(2))), True(), ErrorIDs.NO_ERROR) - val rngState = RNGStateLiteral(Array(1L, 2L, 3L, 4L)) + val rngState = RNGStateLiteral() def collect(ir: IR): IR = ApplyAggOp(FastIndexedSeq.empty, FastIndexedSeq(ir), collectSig) @@ -3152,75 +3113,6 @@ class IRSuite extends HailSuite { assert(hc.irVectors.get(id) eq None) } - @Test def testEvaluations() { - TestFunctions.registerAll() - - def test(x: IR, i: java.lang.Boolean, expectedEvaluations: Int) { - val env = Env.empty[(Any, Type)] - val args = FastIndexedSeq((i, TBoolean)) - - IRSuite.globalCounter = 0 - Interpret[Any](ctx, x, env, args, optimize = false) - assert(IRSuite.globalCounter == expectedEvaluations) - - IRSuite.globalCounter = 0 - Interpret[Any](ctx, x, env, args) - assert(IRSuite.globalCounter == expectedEvaluations) - - IRSuite.globalCounter = 0 - eval(x, env, args, None, None, true, ctx) - assert(IRSuite.globalCounter == expectedEvaluations) - } - - def i = In(0, TBoolean) - - def rngState = RNGStateLiteral() - - def st = ApplySeeded("incr_s", FastSeq(True()), rngState, 0L, TBoolean) - - def sf = ApplySeeded("incr_s", FastSeq(True()), rngState, 0L, TBoolean) - - def sm = ApplySeeded("incr_s", FastSeq(NA(TBoolean)), rngState, 0L, TBoolean) - - def vt = ApplySeeded("incr_v", FastSeq(True()), rngState, 0L, TBoolean) - - def vf = ApplySeeded("incr_v", FastSeq(True()), rngState, 0L, TBoolean) - - def vm = ApplySeeded("incr_v", FastSeq(NA(TBoolean)), rngState, 0L, TBoolean) - - // baseline - test(st, true, 1); test(sf, true, 1); test(sm, true, 1) - test(vt, true, 1); test(vf, true, 1); test(vm, true, 0) - - // if - // condition - test(If(st, i, True()), true, 1) - test(If(sf, i, True()), true, 1) - test(If(sm, i, True()), true, 1) - - test(If(vt, i, True()), true, 1) - test(If(vf, i, True()), true, 1) - test(If(vm, i, True()), true, 0) - - // consequent - test(If(i, st, True()), true, 1) - test(If(i, sf, True()), true, 1) - test(If(i, sm, True()), true, 1) - - test(If(i, vt, True()), true, 1) - test(If(i, vf, True()), true, 1) - test(If(i, vm, True()), true, 0) - - // alternate - test(If(i, True(), st), false, 1) - test(If(i, True(), sf), false, 1) - test(If(i, True(), sm), false, 1) - - test(If(i, True(), vt), false, 1) - test(If(i, True(), vf), false, 1) - test(If(i, True(), vm), false, 0) - } - @Test def testArrayContinuationDealsWithIfCorrectly() { val ir = ToArray(StreamMap( If(IsNA(In(0, TBoolean)), diff --git a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala index 1240aa50bb9..2f194636d95 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala @@ -55,9 +55,12 @@ class MatrixIRSuite extends HailSuite { } } - def rangeMatrix(nRows: Int = 20, nCols: Int = 20, nPartitions: Option[Int] = Some(4)): MatrixIR = { + def rangeMatrix(nRows: Int = 20, nCols: Int = 20, nPartitions: Option[Int] = Some(4), uids: Boolean = false): MatrixIR = { val reader = MatrixRangeReader(nRows, nCols, nPartitions) - val requestedType = reader.fullMatrixTypeWithoutUIDs + val requestedType = if (uids) + reader.fullMatrixType + else + reader.fullMatrixTypeWithoutUIDs MatrixRead(requestedType, false, false, reader) } @@ -277,12 +280,18 @@ class MatrixIRSuite extends HailSuite { } @Test def testMatrixFiltersWorkWithRandomness() { - val range = rangeMatrix(20, 20, Some(4)) - val rand = ApplySeeded("rand_bool", FastIndexedSeq(0.5), RNGStateLiteral(), seed=0, TBoolean) - - val cols = Interpret(MatrixFilterCols(range, rand), ctx, optimize = true).toMatrixValue(range.typ.colKey).nCols - val rows = Interpret(MatrixFilterRows(range, rand), ctx, optimize = true).rvd.count() - val entries = Interpret(MatrixEntriesTable(MatrixFilterEntries(range, rand)), ctx, optimize = true).rvd.count() + val range = rangeMatrix(20, 20, Some(4), uids = true) + def rand(rng: IR): IR = + ApplySeeded("rand_bool", FastIndexedSeq(0.5), rng, 0, TBoolean) + + val colUID = GetField(Ref("sa", range.typ.colType), MatrixReader.colUIDFieldName) + val colRNG = RNGSplit(RNGStateLiteral(), colUID) + val cols = Interpret(MatrixFilterCols(range, rand(colRNG)), ctx, optimize = true).toMatrixValue(range.typ.colKey).nCols + val rowUID = GetField(Ref("va", range.typ.rowType), MatrixReader.rowUIDFieldName) + val rowRNG = RNGSplit(RNGStateLiteral(), rowUID) + val rows = Interpret(MatrixFilterRows(range, rand(rowRNG)), ctx, optimize = true).rvd.count() + val entryRNG = RNGSplit(RNGStateLiteral(), MakeTuple.ordered(FastSeq(rowUID, colUID))) + val entries = Interpret(MatrixEntriesTable(MatrixFilterEntries(range, rand(entryRNG))), ctx, optimize = true).rvd.count() assert(cols < 20 && cols > 0) assert(rows < 20 && rows > 0) diff --git a/hail/src/test/scala/is/hail/expr/ir/RandomFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RandomFunctionsSuite.scala deleted file mode 100644 index a5c10c11b02..00000000000 --- a/hail/src/test/scala/is/hail/expr/ir/RandomFunctionsSuite.scala +++ /dev/null @@ -1,153 +0,0 @@ -package is.hail.expr.ir - -import is.hail.TestUtils._ -import is.hail.expr.ir.TestUtils._ -import is.hail.asm4s.Code -import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.{IRRandomness, RegistryFunctions} -import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.stypes.primitives.{SInt32, SInt64} -import is.hail.types.virtual.{TArray, TFloat64, TInt32, TInt64, TRNGState, TStream} -import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} -import org.apache.spark.sql.Row -import org.testng.annotations.{BeforeClass, Test} - -class TestIRRandomness(val seed: Long) extends IRRandomness(seed) { - private[this] var i = -1 - var partitionIndex: Int = 0 - - override def reset(pidx: Int) { - super.reset(pidx) - partitionIndex = 0 - i = -1 - } - - def counter(): Int = { - i += 1 - i - } -} - -object TestRandomFunctions extends RegistryFunctions { - def getTestRNG(mb: EmitMethodBuilder[_], seed: Long): Code[TestIRRandomness] = { - val rng = mb.genFieldThisRef[IRRandomness]() - mb.ecb.rngs += rng -> Code.checkcast[IRRandomness](Code.newInstance[TestIRRandomness, Long](seed)) - Code.checkcast[TestIRRandomness](rng) - } - - def registerAll() { - registerSeeded0("counter_seeded", TInt32, SInt32) { case (cb, r, rt, seed) => - primitive(cb.memoize(getTestRNG(cb.emb, seed).invoke[Int]("counter"))) - } - - registerSeeded0("seed_seeded", TInt64, SInt64) { case (cb, r, rt, seed) => - primitive(cb.memoize(getTestRNG(cb.emb, seed).invoke[Long]("seed"))) - } - - registerSeeded0("pi_seeded", TInt32, SInt32) { case (cb, r, rt, seed) => - primitive(cb.memoize(getTestRNG(cb.emb, seed).invoke[Int]("partitionIndex"))) - } - } -} - -class RandomFunctionsSuite extends HailSuite { - - implicit val execStrats = ExecStrategy.javaOnly - - def counter = ApplySeeded("counter_seeded", FastSeq(), RNGStateLiteral(), 0L, TInt32) - val partitionIdx = ApplySeeded("pi_seeded", FastSeq(), RNGStateLiteral(), 0L, TInt32) - - def mapped2(n: Int, npart: Int) = TableMapRows( - TableRange(n, npart), - InsertFields(Ref("row", TableRange(1, 1).typ.rowType), - FastSeq( - "pi" -> partitionIdx, - "counter" -> counter))) - - @BeforeClass def registerFunctions() { - TestRandomFunctions.registerAll() - } - - @Test def testRandomAcrossJoins() { - def asArray(ir: TableIR) = Interpret(ir, ctx).rdd.collect() - - val joined = TableJoin( - mapped2(10, 4), - TableRename(mapped2(10, 3), Map("pi" -> "pi2", "counter" -> "counter2"), Map.empty), - "left") - - val expected = asArray(mapped2(10, 4)).zip(asArray(mapped2(10, 3))) - .map { case (Row(idx1, pi1, c1), Row(idx2, pi2, c2)) => - assert(idx1 == idx2) - Row(idx1, pi1, c1, pi2, c2) - } - - assert(asArray(joined) sameElements expected) - } - - @Test def testRepartitioningAfterRandomness() { - val mapped = Interpret(mapped2(15, 4), ctx).rvd - val newRangeBounds = FastIndexedSeq( - Interval(Row(0), Row(4), true, true), - Interval(Row(4), Row(10), false, true), - Interval(Row(10), Row(14), false, true)) - val newPartitioner = mapped.partitioner.copy(rangeBounds=newRangeBounds) - - ExecuteContext.scoped() { ctx => - val repartitioned = mapped.repartition(ctx, newPartitioner) - val cachedAndRepartitioned = mapped.cache(ctx).repartition(ctx, newPartitioner) - - assert(mapped.toRows.collect() sameElements repartitioned.toRows.collect()) - assert(mapped.toRows.collect() sameElements cachedAndRepartitioned.toRows.collect()) - } - } - - @Test def testInterpretIncrementsCorrectly() { - assertEvalsTo( - ToArray(StreamMap(StreamRange(0, 3, 1), "i", counter * counter)), - FastIndexedSeq(0, 1, 4)) - - assertEvalsTo( - StreamFold(StreamRange(0, 3, 1), -1, "j", "i", counter + counter), - 4) - - assertEvalsTo( - ToArray(StreamFilter(StreamRange(0, 3, 1), "i", Ref("i", TInt32).ceq(counter) && counter.ceq(counter))), - FastIndexedSeq(0, 1, 2)) - - assertEvalsTo( - ToArray(StreamFlatMap(StreamRange(0, 3, 1), - "i", - MakeStream(FastSeq(counter, counter, counter), TStream(TInt32)))), - FastIndexedSeq(0, 0, 0, 1, 1, 1, 2, 2, 2)) - } - - @Test def testRepartitioningSimplifyRules() { - val tir = - TableMapRows( - TableHead( - TableMapRows( - TableRange(10, 3), - Ref("row", TableRange(1, 1).typ.rowType)), - 5L), - InsertFields( - Ref("row", TableRange(1, 1).typ.rowType), - FastSeq( - "pi" -> partitionIdx, - "counter" -> counter))) - - val expected = Interpret(tir, ctx).rvd.toRows.collect() - val actual = CompileAndEvaluate[IndexedSeq[Row]](ctx, GetField(collect(tir), "rows"), false) - - assert(expected.sameElements(actual)) - } - - @Test def testRandCat() { - val seed = 5L - assertEvalsTo(invokeSeeded("rand_cat", seed, TInt32, NA(TRNGState), MakeArray(IndexedSeq[IR](0.1), TArray(TFloat64))), 0) - assertEvalsTo(invokeSeeded("rand_cat", seed, TInt32, NA(TRNGState), MakeArray(IndexedSeq[IR](0.3, 0.2, 0.95, 0.05), TArray(TFloat64))), 1) - assertEvalsTo(invokeSeeded("rand_cat", seed, TInt32, NA(TRNGState), NA(TArray(TFloat64))), null) - assertFatal(invokeSeeded("rand_cat", seed, TInt32, NA(TRNGState), MakeArray(IndexedSeq[IR](0.3, NA(TFloat64)), TArray(TFloat64))), "rand_cat") - } -} diff --git a/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala index 1db3c97e1ae..b0b5d8ca0ee 100644 --- a/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala @@ -1,26 +1,178 @@ package is.hail.expr.ir import is.hail.HailSuite +import is.hail.asm4s._ +import is.hail.types.physical.stypes.concrete.{SCanonicalRNGStateSettable, SCanonicalRNGStateValue, SRNGState, SRNGStateStaticSizeValue} +import is.hail.utils.FastIndexedSeq import org.apache.commons.math3.distribution.ChiSquaredDistribution import org.testng.annotations.Test class RandomSuite extends HailSuite { + // from skein_golden_kat_short_internals.txt in the skein source + val threefryTestCases = FastIndexedSeq( + ( + Array(0x0L, 0x0L, 0x0L, 0x0L), + Array(0x0L, 0x0L), + Array(0x0L, 0x0L, 0x0L, 0x0L), + Array(0x09218EBDE6C85537L, 0x55941F5266D86105L, 0x4BD25E16282434DCL, 0xEE29EC846BD2E40BL) + ), ( + Array(0x1716151413121110L, 0x1F1E1D1C1B1A1918L, 0x2726252423222120L, 0x2F2E2D2C2B2A2928L), + Array(0x0706050403020100L, 0x0F0E0D0C0B0A0908L), + Array(0xF8F9FAFBFCFDFEFFL, 0xF0F1F2F3F4F5F6F7L, 0xE8E9EAEBECEDEEEFL, 0xE0E1E2E3E4E5E6E7L), + Array(0x008CF75D18C19DA0L, 0x1D7D14BE2266E7D8L, 0x5D09E0E985FE673BL, 0xB4A5480C6039B172L) + )) + @Test def testThreefry() { - val k = Array.fill[Long](4)(0) - val tf = Threefry(k) - val x = Array.fill[Long](4)(0) - val expected = Array( - 0x09218EBDE6C85537L, - 0x55941F5266D86105L, - 0x4BD25E16282434DCL, - 0xEE29EC846BD2E40BL - ) - tf(x, 0) - assert(x sameElements expected) - - val rand = new ThreefryRandomEngine(k, Array.fill(4)(0L), 0, tweak = 0) - val y = Array.fill(4)(rand.nextLong()) - assert(y sameElements expected) + for { + (key, tweak, input, expected) <- threefryTestCases + } { + val expandedKey = Threefry.expandKey(key) + val tf = Threefry(key) + + var x = input.clone() + tf(x, tweak) + assert(x sameElements expected) + + x = input.clone() + Threefry.encryptUnrolled( + expandedKey(0), expandedKey(1), expandedKey(2), expandedKey(3), expandedKey(4), + tweak(0), tweak(1), x) + assert(x sameElements expected) + + x = input.clone() + Threefry.encrypt(expandedKey, tweak, x) + assert(x sameElements expected) + } + } + + def pmacStagedStaticSize(staticID: Long, size: Int): AsmFunction1[Array[Long], Array[Long]] = { + val f = EmitFunctionBuilder[Array[Long], Array[Long]](ctx, "pmacStaticSize") + f.emb.emitWithBuilder { cb => + val message = f.mb.getArg[Array[Long]](1) + var state = SRNGStateStaticSizeValue(cb) + for (i <- 0 until size) { + state = state.splitDyn(cb, cb.memoize(message(i))) + } + state = state.splitStatic(cb, staticID) + + val result = state.rand(cb) + val resArray = cb.memoize(Code.newArray[Long](4)) + cb.append(resArray(0) = result(0)) + cb.append(resArray(1) = result(1)) + cb.append(resArray(2) = result(2)) + cb.append(resArray(3) = result(3)) + + resArray + } + f.result(ctx)(new HailClassLoader(getClass.getClassLoader)) + } + + def pmacEngineStagedStaticSize(staticID: Long, size: Int): AsmFunction1[Array[Long], ThreefryRandomEngine] = { + val f = EmitFunctionBuilder[Array[Long], ThreefryRandomEngine](ctx, "pmacStaticSize") + f.emb.emitWithBuilder { cb => + val message = f.mb.getArg[Array[Long]](1) + var state = SRNGStateStaticSizeValue(cb) + for (i <- 0 until size) { + state = state.splitDyn(cb, cb.memoize(message(i))) + } + state = state.splitStatic(cb, staticID) + + val engine = cb.memoize(Code.invokeScalaObject0[ThreefryRandomEngine]( + ThreefryRandomEngine.getClass, "apply")) + state.copyIntoEngine(cb, engine) + engine + } + f.result(ctx)(new HailClassLoader(getClass.getClassLoader)) + } + + def pmacStagedDynSize(staticID: Long): AsmFunction1[Array[Long], Array[Long]] = { + val f = EmitFunctionBuilder[Array[Long], Array[Long]](ctx, "pmacDynSize") + f.emb.emitWithBuilder { cb => + val message = f.mb.getArg[Array[Long]](1) + val state = cb.newSLocal(SRNGState(None), "state").asInstanceOf[SCanonicalRNGStateSettable] + cb.assign(state, SCanonicalRNGStateValue(cb)) + val i = cb.newLocal[Int]("i", 0) + val len = cb.memoize(message.length()) + cb.forLoop({}, i < len, cb.assign(i, i + 1), { + cb.assign(state, state.splitDyn(cb, cb.memoize(message(i)))) + }) + cb.assign(state, state.splitStatic(cb, staticID)) + + val result = state.rand(cb) + val resArray = cb.memoize(Code.newArray[Long](4)) + cb.append(resArray(0) = result(0)) + cb.append(resArray(1) = result(1)) + cb.append(resArray(2) = result(2)) + cb.append(resArray(3) = result(3)) + + resArray + } + f.result(ctx)(new HailClassLoader(getClass.getClassLoader)) + } + + def pmacEngineStagedDynSize(staticID: Long): AsmFunction1[Array[Long], ThreefryRandomEngine] = { + val f = EmitFunctionBuilder[Array[Long], ThreefryRandomEngine](ctx, "pmacDynSize") + f.emb.emitWithBuilder { cb => + val message = f.mb.getArg[Array[Long]](1) + val state = cb.newSLocal(SRNGState(None), "state").asInstanceOf[SCanonicalRNGStateSettable] + cb.assign(state, SCanonicalRNGStateValue(cb)) + val i = cb.newLocal[Int]("i", 0) + val len = cb.memoize(message.length()) + cb.forLoop({}, i < len, cb.assign(i, i + 1), { + cb.assign(state, state.splitDyn(cb, cb.memoize(message(i)))) + }) + cb.assign(state, state.splitStatic(cb, staticID)) + + val engine = cb.memoize(Code.invokeScalaObject0[ThreefryRandomEngine]( + ThreefryRandomEngine.getClass, "apply")) + state.copyIntoEngine(cb, engine) + engine + } + f.result(ctx)(new HailClassLoader(getClass.getClassLoader)) + } + + val pmacTestCases = FastIndexedSeq( + (Array[Long](), 0L), + (Array[Long](100, 101), 10L), + (Array[Long](100, 101, 102, 103), 20L), + (Array[Long](100, 101, 102, 103, 104), 30L) + ) + + @Test def testPMAC() { + for { + (message, staticID) <- pmacTestCases + } { + val res1 = Threefry.pmac(ctx.rngNonce, staticID, message) + val res2 = pmacStagedStaticSize(staticID, message.length)(message) + val res3 = pmacStagedDynSize(staticID)(message) + assert(res1 sameElements res2) + assert(res1 sameElements res3) + } + } + + @Test def testRandomEngine() { + for { + (message, staticID) <- pmacTestCases + } { + val (hash, finalTweak) = Threefry.pmacHash(ctx.rngNonce, staticID, message) + val engine1 = pmacEngineStagedStaticSize(staticID, message.length)(message) + val engine2 = pmacEngineStagedDynSize(staticID)(message) + + var expected = hash.clone() + Threefry.encrypt(Threefry.defaultKey, Array(finalTweak, 0L), expected) + assert(Array.fill(4)(engine1.nextLong()) sameElements expected) + assert(Array.fill(4)(engine2.nextLong()) sameElements expected) + + expected = hash.clone() + Threefry.encrypt(Threefry.defaultKey, Array(finalTweak, 1L), expected) + assert(Array.fill(4)(engine1.nextLong()) sameElements expected) + assert(Array.fill(4)(engine2.nextLong()) sameElements expected) + + expected = hash.clone() + Threefry.encrypt(Threefry.defaultKey, Array(finalTweak, 2L), expected) + assert(Array.fill(4)(engine1.nextLong()) sameElements expected) + assert(Array.fill(4)(engine2.nextLong()) sameElements expected) + } } def runChiSquareTest(samples: Int, buckets: Int)(sample: => Int) { @@ -46,7 +198,7 @@ class RandomSuite extends HailSuite { @Test def testRandomInt() { val n = 1 << 25 val k = 1 << 15 - val rand = ThreefryRandomEngine() + val rand = ThreefryRandomEngine.randState() runChiSquareTest(n, k) { rand.nextInt() & (k - 1) } @@ -55,7 +207,7 @@ class RandomSuite extends HailSuite { @Test def testBoundedUniformInt() { var n = 1 << 25 var k = 1 << 15 - val rand = ThreefryRandomEngine() + val rand = ThreefryRandomEngine.randState() runChiSquareTest(n, k) { rand.nextInt(k) } @@ -70,7 +222,7 @@ class RandomSuite extends HailSuite { @Test def testBoundedUniformLong() { var n = 1 << 25 var k = 1 << 15 - val rand = ThreefryRandomEngine() + val rand = ThreefryRandomEngine.randState() runChiSquareTest(n, k) { rand.nextLong(k).toInt } @@ -85,7 +237,7 @@ class RandomSuite extends HailSuite { @Test def testUniformDouble() { val n = 1 << 25 val k = 1 << 15 - val rand = ThreefryRandomEngine() + val rand = ThreefryRandomEngine.randState() runChiSquareTest(n, k) { val r = rand.nextDouble() assert(r >= 0.0 && r < 1.0, r) @@ -96,7 +248,7 @@ class RandomSuite extends HailSuite { @Test def testUniformFloat() { val n = 1 << 25 val k = 1 << 15 - val rand = ThreefryRandomEngine() + val rand = ThreefryRandomEngine.randState() runChiSquareTest(n, k) { val r = rand.nextFloat() assert(r >= 0.0 && r < 1.0, r) diff --git a/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala b/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala index 2a98acc86b5..b09c0758a1e 100644 --- a/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala +++ b/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala @@ -716,20 +716,27 @@ class BlockMatrixSuite extends HailSuite { @Test def randomTest() { - var lm1 = BlockMatrix.random(5, 10, 2, seed = 1, gaussian = false).toBreezeMatrix() - var lm2 = BlockMatrix.random(5, 10, 2, seed = 1, gaussian = false).toBreezeMatrix() - var lm3 = BlockMatrix.random(5, 10, 2, seed = 2, gaussian = false).toBreezeMatrix() + var lm1 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() + var lm2 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() + var lm3 = BlockMatrix.random(5, 10, 2, staticUID = 2, nonce = 1, gaussian = false).toBreezeMatrix() + var lm4 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 2, gaussian = false).toBreezeMatrix() + println(lm1) assert(lm1 === lm2) assert(lm1 !== lm3) + assert(lm1 !== lm4) + assert(lm3 !== lm4) assert(lm1.data.forall(x => x >= 0 && x <= 1)) - lm1 = BlockMatrix.random(5, 10, 2, seed = 1, gaussian = true).toBreezeMatrix() - lm2 = BlockMatrix.random(5, 10, 2, seed = 1, gaussian = true).toBreezeMatrix() - lm3 = BlockMatrix.random(5, 10, 2, seed = 2, gaussian = true).toBreezeMatrix() + lm1 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = true).toBreezeMatrix() + lm2 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = true).toBreezeMatrix() + lm3 = BlockMatrix.random(5, 10, 2, staticUID = 2, nonce = 1, gaussian = true).toBreezeMatrix() + lm4 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 2, gaussian = true).toBreezeMatrix() assert(lm1 === lm2) assert(lm1 !== lm3) + assert(lm1 !== lm4) + assert(lm3 !== lm4) } @Test