Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Oct 3, 2022
1 parent 4cb72c8 commit 54355a8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
3 changes: 2 additions & 1 deletion hail/python/hail/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
NDArrayShape, NDArrayReshape, NDArrayMap, NDArrayMap2, NDArrayRef, NDArraySlice, NDArraySVD, \
NDArrayReindex, NDArrayAgg, NDArrayMatMul, NDArrayQR, NDArrayInv, NDArrayConcat, NDArrayWrite, \
ArraySort, ToSet, ToDict, toArray, ToArray, CastToArray, toStream, ToStream, \
LowerBoundOnOrderedCollection, GroupByKey, StreamMap, StreamZip, \
LowerBoundOnOrderedCollection, GroupByKey, StreamTake, StreamMap, StreamZip, \
StreamFilter, StreamFlatMap, StreamFold, StreamScan, \
StreamJoinRightDistinct, StreamFor, AggFilter, AggExplode, AggGroupBy, \
AggArrayPerElement, BaseApplyAggOp, ApplyAggOp, ApplyScanOp, AggFold, Begin, \
Expand Down Expand Up @@ -174,6 +174,7 @@
'ToStream',
'LowerBoundOnOrderedCollection',
'GroupByKey',
'StreamTake',
'StreamMap',
'StreamZip',
'StreamFilter',
Expand Down
3 changes: 2 additions & 1 deletion hail/python/hail/ir/renderer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from hail import ir
from hail.expr.types import tstream
import abc
from typing import Sequence, MutableSequence, List, Set, Dict, Optional
from collections import namedtuple
Expand Down Expand Up @@ -220,7 +221,7 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]:

if child_idx >= len(node.children):
# mark node as visited at potential let insertion site
if not node.is_effectful():
if not (node.is_effectful() or isinstance(node.typ, tstream)):
bind_depth = frame.bind_depth()
if bind_depth < frame.min_value_binding_depth:
if frame.scan_scope:
Expand Down
15 changes: 15 additions & 0 deletions hail/python/test/hail/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,21 @@ def test_cse(self):
' (Ref __cse_1)))')
assert expected == CSERenderer()(x)

def test_stream_cse(self):
x = ir.StreamRange(ir.I32(0), ir.I32(10), ir.I32(1))
a1 = ir.ToArray(x)
a2 = ir.ToArray(x)
t = ir.MakeTuple([a1, a2])
expected = (
'(Let __cse_1 (I32 0)'
' (Let __cse_2 (I32 10)'
' (Let __cse_3 (I32 1)'
' (MakeTuple (0 1)'
' (ToArray (StreamRange 1 False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))'
' (ToArray (StreamRange 1 False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))))))'
)
assert expected == CSERenderer()(t)

def test_cse2(self):
x = ir.I32(5)
y = ir.I32(4)
Expand Down

0 comments on commit 54355a8

Please sign in to comment.