Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Oct 18, 2022
1 parent 49ba543 commit 5e05d33
Show file tree
Hide file tree
Showing 9 changed files with 15 additions and 17 deletions.
4 changes: 2 additions & 2 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,7 +2462,7 @@ def rand_bool(p, seed=None) -> BooleanExpression:
return _seeded_func("rand_bool", tbool, seed, p)


@typecheck(mean=expr_float64, sd=expr_float64, seed=nullable(int), size=nullable(tupleof(expr_int32)))
@typecheck(mean=expr_float64, sd=expr_float64, seed=nullable(int), size=nullable(tupleof(expr_int64)))
def rand_norm(mean=0, sd=1, seed=None, size=None) -> Float64Expression:
"""Samples from a normal distribution with mean `mean` and standard
deviation `sd`.
Expand Down Expand Up @@ -2588,7 +2588,7 @@ def rand_pois(lamb, seed=None) -> Float64Expression:
return _seeded_func("rand_pois", tfloat64, seed, lamb)


@typecheck(lower=expr_float64, upper=expr_float64, seed=nullable(int), size=nullable(tupleof(expr_int32)))
@typecheck(lower=expr_float64, upper=expr_float64, seed=nullable(int), size=nullable(tupleof(expr_int64)))
def rand_unif(lower=0.0, upper=1.0, seed=None, size=None) -> Float64Expression:
"""Samples from a uniform distribution within the interval
[`lower`, `upper`].
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/blockmatrix_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def __init__(self, static_rng_uid, gaussian, shape, block_size):
self.block_size = block_size

def head_str(self):
return '{} {} {} {} {}'.format(self.static_rng_uid,
return '{} {} {} {}'.format(self.static_rng_uid,
self.gaussian,
_serialize_list(self.shape),
self.block_size)
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from hail.ir.blockmatrix_writer import BlockMatrixWriter, BlockMatrixMultiWriter
from hail.typecheck import typecheck, typecheck_method, sequenceof, numeric, \
sized_tupleof, nullable, tupleof, anytype, func_spec
from hail.utils.java import Env, HailUserError, warning
from hail.utils.java import Env, HailUserError
from hail.utils.misc import escape_str, parsable_strings, escape_id
from hail.utils.jsonx import dump_json
from .utils import default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid
Expand Down
1 change: 1 addition & 0 deletions hail/python/hail/ir/matrix_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hail.ir.base_ir import BaseIR, MatrixIR
from hail.ir.utils import modify_deep_field, zip_with_index, zip_with_index_field, default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid
import hail.ir.ir as ir
from hail.utils import FatalError
from hail.utils.misc import escape_str, parsable_strings, escape_id
from hail.utils.jsonx import dump_json
from hail.utils.java import Env
Expand Down
1 change: 0 additions & 1 deletion hail/python/hail/utils/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys
import re

import hail
from hailtop.config import configuration_of


Expand Down
2 changes: 1 addition & 1 deletion hail/python/test/hail/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def value_irs(self):
s = ir.Ref('s', env['s'])
t = ir.Ref('t', env['t'])
call = ir.Ref('call', env['call'])
rngState = ir.RNGStateLiteral((1, 2, 3, 4))
rngState = ir.RNGStateLiteral()

table = ir.TableRange(5, 3)

Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Random.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ class ThreefryRandomEngine(

private[this] val poisState = Poisson.create_random_state()

def runif(min: Double, max: Double): Double = min + (max - min) * nextDouble()

def rnorm(mean: Double, sd: Double): Double = mean + sd * nextGaussian()

def rpois(lambda: Double): Double = Poisson.random(lambda, this, poisState)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@ object RandomSeededFunctions extends RegistryFunctions {
primitive(cb.memoize(rand_unif(cb, rngState.rand(cb)) * (max.value - min.value) + min.value))
}

registerSCode5("rand_unif_nd", TRNGState, TInt32, TInt32, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), {
registerSCode5("rand_unif_nd", TRNGState, TInt64, TInt64, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), {
case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => PCanonicalNDArray(PFloat64(true), 2, true).sType
}) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt32Value, nCols: SInt32Value, min, max, errorID) =>
val nRowsL = cb.memoize(nRows.value.toL)
val nColsL = cb.memoize(nCols.value.toL)
val result = rt.pType.constructUnintialized(FastIndexedSeq(SizeValueDyn(nRowsL), SizeValueDyn(nColsL)), cb, r.region)
}) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt64Value, nCols: SInt64Value, min, max, errorID) =>
val result = rt.pType.constructUnintialized(FastIndexedSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), cb, r.region)
val rng = cb.emb.getThreefryRNG()
rngState.copyIntoEngine(cb, rng)
result.coiterateMutate(cb, r.region) { _ =>
Expand Down Expand Up @@ -138,12 +136,10 @@ object RandomSeededFunctions extends RegistryFunctions {
primitive(rngState.rand(cb)(0))
}

registerSCode5("rand_norm_nd", TRNGState, TInt32, TInt32, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), {
registerSCode5("rand_norm_nd", TRNGState, TInt64, TInt64, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), {
case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => PCanonicalNDArray(PFloat64(true), 2, true).sType
}) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt32Value, nCols: SInt32Value, mean, sd, errorID) =>
val nRowsL = cb.memoize(nRows.value.toL)
val nColsL = cb.memoize(nCols.value.toL)
val result = rt.pType.constructUnintialized(FastIndexedSeq(SizeValueDyn(nRowsL), SizeValueDyn(nColsL)), cb, r.region)
}) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt64Value, nCols: SInt64Value, mean, sd, errorID) =>
val result = rt.pType.constructUnintialized(FastIndexedSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), cb, r.region)
val rng = cb.emb.getThreefryRNG()
rngState.copyIntoEngine(cb, rng)
result.coiterateMutate(cb, r.region) { _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ object LowerBlockMatrixIR {
bmir match {
case BlockMatrixRead(reader) => reader.lower(ctx)
case x@BlockMatrixRandom(staticUID, gaussian, shape, blockSize) =>
new BlockMatrixStage(IndexedSeq(), Array(), TTuple(TInt64, TInt64)) {
new BlockMatrixStage(IndexedSeq(), Array(), TTuple(TInt64, TInt64, TInt32)) {
def blockContext(idx: (Int, Int)): IR = {
val (m, n) = x.typ.blockShape(idx._1, idx._2)
MakeTuple.ordered(FastSeq(m, n, idx._1 * x.typ.nColBlocks + idx._2))
Expand Down

0 comments on commit 5e05d33

Please sign in to comment.