diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index 38f81655bd2d..09b3c49d8059 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -11,7 +11,7 @@ NDArrayShape, NDArrayReshape, NDArrayMap, NDArrayMap2, NDArrayRef, NDArraySlice, NDArraySVD, \ NDArrayReindex, NDArrayAgg, NDArrayMatMul, NDArrayQR, NDArrayInv, NDArrayConcat, NDArrayWrite, \ ArraySort, ToSet, ToDict, toArray, ToArray, CastToArray, toStream, ToStream, \ - LowerBoundOnOrderedCollection, GroupByKey, StreamMap, StreamZip, \ + LowerBoundOnOrderedCollection, GroupByKey, StreamTake, StreamMap, StreamZip, \ StreamFilter, StreamFlatMap, StreamFold, StreamScan, \ StreamJoinRightDistinct, StreamFor, AggFilter, AggExplode, AggGroupBy, \ AggArrayPerElement, BaseApplyAggOp, ApplyAggOp, ApplyScanOp, AggFold, Begin, \ @@ -174,6 +174,7 @@ 'ToStream', 'LowerBoundOnOrderedCollection', 'GroupByKey', + 'StreamTake', 'StreamMap', 'StreamZip', 'StreamFilter', diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index b31031361477..3d3d1aa470e8 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,4 +1,5 @@ from hail import ir +from hail.expr.types import tstream import abc from typing import Sequence, MutableSequence, List, Set, Dict, Optional from collections import namedtuple @@ -220,7 +221,7 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: if child_idx >= len(node.children): # mark node as visited at potential let insertion site - if not node.is_effectful(): + if not (node.is_effectful() or isinstance(node.typ, tstream)): bind_depth = frame.bind_depth() if bind_depth < frame.min_value_binding_depth: if frame.scan_scope: diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index b4f101101bb7..547fd79deb1c 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -418,6 +418,21 @@ def test_cse(self): ' (Ref __cse_1)))') 
assert expected == CSERenderer()(x)

+    def test_stream_cse(self):
+        # One stream node feeds both ToArray nodes; expected output repeats it instead of let-binding it.
+        stream = ir.StreamRange(ir.I32(0), ir.I32(10), ir.I32(1))
+        pair = ir.MakeTuple([ir.ToArray(stream), ir.ToArray(stream)])
+        rendered = CSERenderer()(pair)
+        expected = (
+            '(Let __cse_1 (I32 0)'
+            ' (Let __cse_2 (I32 10)'
+            ' (Let __cse_3 (I32 1)'
+            ' (MakeTuple (0 1)'
+            ' (ToArray (StreamRange 1 False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))'
+            ' (ToArray (StreamRange 1 False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))))))'
+        )
+        assert expected == rendered
+
     def test_cse2(self):
         x = ir.I32(5)
         y = ir.I32(4)