diff --git a/hail/src/main/scala/is/hail/expr/ir/Binds.scala b/hail/src/main/scala/is/hail/expr/ir/Binds.scala index d109cf42a48..fd7615f1f53 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Binds.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Binds.scala @@ -106,41 +106,15 @@ final case class Bindings( object Bindings { val empty: Bindings = Bindings(FastSeq.empty, AggEnv.NoOp, AggEnv.NoOp, FastSeq.empty, false) - // Create a `Bindings` which cannot see anything bound in the enclosing context. - private def inFreshScope( - eval: IndexedSeq[(String, Type)] = FastSeq.empty, - agg: Option[IndexedSeq[(String, Type)]] = None, - scan: Option[IndexedSeq[(String, Type)]] = None, - relational: IndexedSeq[(String, Type)] = FastSeq.empty, - ): Bindings = Bindings( - eval, - agg.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop), - scan.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop), - relational, - dropEval = true, - ) - /** Returns the environment of the `i`th child of `ir` given the environment of the parent node * `ir`. */ def get[E <: GenericBindingEnv[E, Type]](ir: BaseIR, i: Int, baseEnv: E): E = { - val res = ir match { - case ir: MatrixIR => childEnvMatrix(ir, i, baseEnv) - case ir: TableIR => childEnvTable(ir, i, baseEnv) - case ir: BlockMatrixIR => childEnvBlockMatrix(ir, i, baseEnv) - case ir: IR => childEnvValue(ir, i, baseEnv) - } - val newRes = get2(ir, i, baseEnv) - assert(res == newRes, s"\nnew = $newRes\n old = $res\n node = ${ir.getClass}") - res - } - - private def get2[E <: GenericBindingEnv[E, Type]](ir: BaseIR, i: Int, baseEnv: E): E = { val bindings = ir match { - case ir: MatrixIR => childEnvMatrix2(ir, i) - case ir: TableIR => childEnvTable2(ir, i) - case ir: BlockMatrixIR => childEnvBlockMatrix2(ir, i) - case ir: IR => childEnvValue2(ir, i) + case ir: MatrixIR => childEnvMatrix(ir, i) + case ir: TableIR => childEnvTable(ir, i) + case ir: BlockMatrixIR => childEnvBlockMatrix(ir, i) + case ir: IR => childEnvValue(ir, i) } baseEnv.extend(bindings) } @@ -154,7 +128,21 @@ object Bindings { def segregated[A](ir: BaseIR, i: Int, baseEnv: BindingEnv[A]): SegregatedBindingEnv[A, Type] = get(ir, i, SegregatedBindingEnv(baseEnv)) - private def childEnvMatrix2(ir: MatrixIR, i: Int): Bindings = { + // Create a `Bindings` which cannot see anything bound in the enclosing context. + private def inFreshScope( + eval: IndexedSeq[(String, Type)] = FastSeq.empty, + agg: Option[IndexedSeq[(String, Type)]] = None, + scan: Option[IndexedSeq[(String, Type)]] = None, + relational: IndexedSeq[(String, Type)] = FastSeq.empty, + ): Bindings = Bindings( + eval, + agg.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop), + scan.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop), + relational, + dropEval = true, + ) + + private def childEnvMatrix(ir: MatrixIR, i: Int): Bindings = { ir match { case MatrixMapRows(child, _) if i == 1 => Bindings.inFreshScope( @@ -209,62 +197,7 @@ object Bindings { } } - private def childEnvMatrix[E <: GenericBindingEnv[E, Type]](ir: MatrixIR, i: Int, _baseEnv: E) - : E = { - val baseEnv = _baseEnv.onlyRelational() - ir match { - case MatrixMapRows(child, _) if i == 1 => - baseEnv - .createAgg.createScan - .bindEval(child.typ.rowBindings: _*) - .bindEval("n_cols" -> TInt32) - .bindAgg(child.typ.entryBindings: _*) - .bindScan(child.typ.rowBindings: _*) - case MatrixFilterRows(child, _) if i == 1 => - baseEnv.bindEval(child.typ.rowBindings: _*) - case MatrixMapCols(child, _, _) if i == 1 => - baseEnv - .createAgg.createScan - .bindEval(child.typ.colBindings: _*) - .bindEval("n_rows" -> TInt64) - .bindAgg(child.typ.entryBindings: _*) - .bindScan(child.typ.colBindings: _*) - case MatrixFilterCols(child, _) if i == 1 => - baseEnv.bindEval(child.typ.colBindings: _*) - case MatrixMapEntries(child, _) if i == 1 => - baseEnv.bindEval(child.typ.entryBindings: _*) - case MatrixFilterEntries(child, _) if i == 1 => - baseEnv.bindEval(child.typ.entryBindings: _*) - case MatrixMapGlobals(child, _) if i == 1 => - baseEnv.bindEval(child.typ.globalBindings: _*) - case MatrixAggregateColsByKey(child, _, _) => - if (i == 1) - baseEnv - .bindEval(child.typ.rowBindings: _*) - .createAgg.bindAgg(child.typ.entryBindings: _*) - else if (i == 2) - baseEnv - .bindEval(child.typ.globalBindings: _*) - .createAgg.bindAgg(child.typ.colBindings: _*) - else baseEnv - case MatrixAggregateRowsByKey(child, _, _) => - if (i == 1) - baseEnv - .bindEval(child.typ.colBindings: _*) - .createAgg.bindAgg(child.typ.entryBindings: _*) - else if (i == 2) - baseEnv - .bindEval(child.typ.globalBindings: _*) - .createAgg.bindAgg(child.typ.rowBindings: _*) - else baseEnv - case RelationalLetMatrixTable(name, value, _) if i == 1 => - baseEnv.bindRelational(name -> value.typ) - case _ => - baseEnv - } - } - - private def childEnvTable2(ir: TableIR, i: Int): Bindings = { + private def childEnvTable(ir: TableIR, i: Int): Bindings = { ir match { case TableFilter(child, _) if i == 1 => Bindings.inFreshScope(child.typ.rowBindings) @@ -306,48 +239,7 @@ object Bindings { } } - private def childEnvTable[E <: GenericBindingEnv[E, Type]](ir: TableIR, i: Int, _baseEnv: E) - : E = { - val baseEnv = _baseEnv.onlyRelational() - ir match { - case TableFilter(child, _) if i == 1 => - baseEnv.bindEval(child.typ.rowBindings: _*) - case TableGen(contexts, globals, cname, gname, _, _, _) if i == 2 => - baseEnv.bindEval( - cname -> elementType(contexts.typ), - gname -> globals.typ, - ) - case TableMapGlobals(child, _) if i == 1 => - baseEnv.bindEval(child.typ.globalBindings: _*) - case TableMapRows(child, _) if i == 1 => - baseEnv - .bindEval(child.typ.rowBindings: _*) - .createScan.bindScan(child.typ.rowBindings: _*) - case TableAggregateByKey(child, _) if i == 1 => - baseEnv - .bindEval(child.typ.globalBindings: _*) - .createAgg.bindAgg(child.typ.rowBindings: _*) - case TableKeyByAndAggregate(child, _, _, _, _) => - if (i == 1) - baseEnv - .bindEval(child.typ.globalBindings: _*) - .createAgg.bindAgg(child.typ.rowBindings: _*) - else if (i == 2) - baseEnv.bindEval(child.typ.rowBindings: _*) - else baseEnv - case TableMapPartitions(child, g, p, _, _, _) if i == 1 => - baseEnv.bindEval( - g -> child.typ.globalType, - p -> TStream(child.typ.rowType), - ) - case RelationalLetTable(name, value, _) if i == 1 => - baseEnv.bindRelational(name -> value.typ) - case _ => - baseEnv - } - } - - private def childEnvBlockMatrix2(ir: BlockMatrixIR, i: Int): Bindings = { + private def childEnvBlockMatrix(ir: BlockMatrixIR, i: Int): Bindings = { ir match { case BlockMatrixMap(_, eltName, _, _) if i == 1 => Bindings.inFreshScope(FastSeq(eltName -> TFloat64)) @@ -360,25 +252,7 @@ object Bindings { } } - private def childEnvBlockMatrix[E <: GenericBindingEnv[E, Type]]( - ir: BlockMatrixIR, - i: Int, - _baseEnv: E, - ): E = { - val baseEnv = _baseEnv.onlyRelational() - ir match { - case BlockMatrixMap(_, eltName, _, _) if i == 1 => - baseEnv.bindEval(eltName -> TFloat64) - case BlockMatrixMap2(_, _, lName, rName, _, _) if i == 2 => - baseEnv.bindEval(lName -> TFloat64, rName -> TFloat64) - case RelationalLetBlockMatrix(name, value, _) if i == 1 => - baseEnv.bindRelational(name -> value.typ) - case _ => - baseEnv - } - } - - private def childEnvValue2(ir: IR, i: Int): Bindings = + private def childEnvValue(ir: IR, i: Int): Bindings = ir match { case Block(bindings, _) => val eval = mutable.ArrayBuilder.make[(String, Type)] @@ -617,166 +491,4 @@ object Bindings { Bindings(scan = AggEnv.Promote) else Bindings.empty } - - private def childEnvValue[E <: GenericBindingEnv[E, Type]](ir: IR, i: Int, baseEnv: E): E = - ir match { - case Block(bindings, _) => - var env = baseEnv - for (k <- 0 until i) bindings(k) match { - case Binding(name, value, scope) => - env = env.bindInScope(name, value.typ, scope) - } - if (i < bindings.length) bindings(i).scope match { - case Scope.EVAL => env - case Scope.AGG => env.promoteAgg - case Scope.SCAN => env.promoteScan - } - else env - case TailLoop(name, args, resultType, _) if i == args.length => - baseEnv - .bindEval(args.map { case (name, ir) => name -> ir.typ }: _*) - .bindEval(name -> TTuple(TTuple(args.map(_._2.typ): _*), resultType)) - case StreamMap(a, name, _) if i == 1 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamZip(as, names, _, _, _) if i == as.length => - baseEnv.bindEval(names.zip(as.map(a => elementType(a.typ))): _*) - case StreamZipJoin(as, key, curKey, curVals, _) if i == as.length => - val eltType = tcoerce[TStruct](elementType(as.head.typ)) - baseEnv.bindEval( - curKey -> eltType.typeAfterSelectNames(key), - curVals -> TArray(eltType), - ) - case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) => - if (i == 1) { - val contextType = elementType(contexts.typ) - baseEnv.bindEval(ctxName -> contextType) - } else if (i == 2) { - val eltType = tcoerce[TStruct](elementType(makeProducer.typ)) - baseEnv.bindEval( - curKey -> eltType.typeAfterSelectNames(key), - curVals -> TArray(eltType), - ) - } else baseEnv - case StreamLeftIntervalJoin(left, right, _, _, lEltName, rEltName, _) if i == 2 => - baseEnv.bindEval( - lEltName -> elementType(left.typ), - rEltName -> TArray(elementType(right.typ)), - ) - case StreamFor(a, name, _) if i == 1 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamFlatMap(a, name, _) if i == 1 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamFilter(a, name, _) if i == 1 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamTakeWhile(a, name, _) if i == 1 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamDropWhile(a, name, _) if i == 1 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamFold(a, zero, accumName, valueName, _) if i == 2 => - baseEnv.bindEval(accumName -> zero.typ, valueName -> elementType(a.typ)) - case StreamFold2(a, accum, valueName, _, _) => - if (i <= accum.length) - baseEnv - else if (i < 2 * accum.length + 1) - baseEnv - .bindEval(valueName -> elementType(a.typ)) - .bindEval(accum.map { case (name, value) => (name, value.typ) }: _*) - else - baseEnv.bindEval(accum.map { case (name, value) => (name, value.typ) }: _*) - case StreamBufferedAggregate(stream, _, _, _, name, _, _) if i > 0 => - baseEnv.bindEval(name -> elementType(stream.typ)) - case RunAggScan(a, name, _, _, _, _) if i == 2 || i == 3 => - baseEnv.bindEval(name -> elementType(a.typ)) - case StreamScan(a, zero, accumName, valueName, _) if i == 2 => - baseEnv.bindEval( - accumName -> zero.typ, - valueName -> elementType(a.typ), - ) - case StreamAggScan(a, name, _) if i == 1 => - val eltType = elementType(a.typ) - baseEnv - .bindEval(name -> eltType) - .createScan.bindScan(name -> eltType) - case StreamJoinRightDistinct(ll, rr, _, _, l, r, _, _) if i == 2 => - baseEnv.bindEval( - l -> elementType(ll.typ), - r -> elementType(rr.typ), - ) - case ArraySort(a, left, right, _) if i == 1 => - baseEnv.bindEval( - left -> elementType(a.typ), - right -> elementType(a.typ), - ) - case ArrayMaximalIndependentSet(a, Some((left, right, _))) if i == 1 => - val typ = tcoerce[TBaseStruct](elementType(a.typ)).types.head - val tupleType = TTuple(typ) - baseEnv.noEval.bindEval(left -> tupleType, right -> tupleType) - case AggArrayPerElement(a, elementName, indexName, _, _, isScan) => - if (i == 0) baseEnv.promoteAggOrScan(isScan) - else if (i == 1) - baseEnv - .bindEval(indexName -> TInt32) - .bindAggOrScan( - isScan, - elementName -> elementType(a.typ), - indexName -> TInt32, - ) - else baseEnv - case AggFold(zero, _, _, accumName, otherAccumName, isScan) => - if (i == 0) baseEnv.noAggOrScan(isScan) - else if (i == 1) baseEnv.promoteAggOrScan(isScan).bindEval(accumName -> zero.typ) - else baseEnv.noEval.noAggOrScan(isScan) - .bindEval(accumName -> zero.typ, otherAccumName -> zero.typ) - case NDArrayMap(nd, name, _) if i == 1 => - baseEnv.bindEval(name -> tcoerce[TNDArray](nd.typ).elementType) - case NDArrayMap2(l, r, lName, rName, _, _) if i == 2 => - baseEnv.bindEval( - lName -> tcoerce[TNDArray](l.typ).elementType, - rName -> tcoerce[TNDArray](r.typ).elementType, - ) - case CollectDistributedArray(contexts, globals, cname, gname, _, _, _, _) if i == 2 => - baseEnv.onlyRelational().bindEval( - cname -> elementType(contexts.typ), - gname -> globals.typ, - ) - case TableAggregate(child, _) => - if (i == 1) - baseEnv.onlyRelational() - .bindEval(child.typ.globalBindings: _*) - .createAgg.bindAgg(child.typ.rowBindings: _*) - else baseEnv.onlyRelational() - case MatrixAggregate(child, _) => - if (i == 1) - baseEnv.onlyRelational() - .bindEval(child.typ.globalBindings: _*) - .createAgg.bindAgg(child.typ.entryBindings: _*) - else baseEnv.onlyRelational() - case ApplyAggOp(init, _, _) => - if (i < init.length) baseEnv.noAgg - else baseEnv.promoteAgg - case ApplyScanOp(init, _, _) => - if (i < init.length) baseEnv.noScan - else baseEnv.promoteScan - case AggFilter(_, _, isScan) if i == 0 => - baseEnv.promoteAggOrScan(isScan) - case AggGroupBy(_, _, isScan) if i == 0 => - baseEnv.promoteAggOrScan(isScan) - case AggExplode(a, name, _, isScan) => - if (i == 0) baseEnv.promoteAggOrScan(isScan) - else baseEnv.bindAggOrScan(isScan, name -> elementType(a.typ)) - case StreamAgg(a, name, _) if i == 1 => - baseEnv.createAgg - .bindAgg(name -> elementType(a.typ)) - case RelationalLet(name, value, _) => - if (i == 1) - baseEnv.noAgg.noScan.bindRelational(name -> value.typ) - else - baseEnv.onlyRelational() - case _: LiftMeOut => - baseEnv.onlyRelational() - case _ => - if (UsesAggEnv(ir, i)) baseEnv.promoteAgg - else if (UsesScanEnv(ir, i)) baseEnv.promoteScan - else baseEnv - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Env.scala b/hail/src/main/scala/is/hail/expr/ir/Env.scala index 107fa7a42a4..6ffb70182e0 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Env.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Env.scala @@ -56,9 +56,6 @@ trait GenericBindingEnv[Self, V] { def promoteScan: Self - def promoteAggOrScan(isScan: Boolean): Self = - if (isScan) promoteScan else promoteAgg - def bindEval(bindings: (String, V)*): Self def noEval: Self @@ -67,9 +64,6 @@ trait GenericBindingEnv[Self, V] { def bindScan(bindings: (String, V)*): Self - def bindAggOrScan(isScan: Boolean, bindings: (String, V)*): Self = - if (isScan) bindScan(bindings: _*) else bindAgg(bindings: _*) - def bindInScope(name: String, v: V, scope: Int): Self = scope match { case Scope.EVAL => bindEval(name -> v) case Scope.AGG => bindAgg(name -> v) @@ -80,15 +74,10 @@ trait GenericBindingEnv[Self, V] { def createScan: Self - def createAggOrScan(isScan: Boolean): Self = - if (isScan) createScan else createAgg - def noAgg: Self def noScan: Self - def noAggOrScan(isScan: Boolean): Self = if (isScan) noScan else noAgg - def onlyRelational(keepAggCapabilities: Boolean = false): Self def bindRelational(bindings: (String, V)*): Self