Skip to content

Commit

Permalink
Fixing nim-lang#15243, WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yglukhov committed Sep 29, 2020
1 parent 57b7841 commit 27b68b2
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 14 deletions.
1 change: 1 addition & 0 deletions compiler/ast.nim
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ type
nfDefaultRefsParam # a default param value references another parameter
# the flag is applied to proc default values and to calls
nfExecuteOnReload # A top-level statement that will be executed during reloads
nfStateSplit # whether a statement needs to be split into states (used in closure iterator transformation)

TNodeFlags* = set[TNodeFlag]
TTypeFlag* = enum # keep below 32 for efficiency reasons (now: ~40)
Expand Down
299 changes: 297 additions & 2 deletions compiler/closureiters.nim
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ type
curExcHandlingState: int # Negative for except, positive for finally
nearestFinally: int # Index of the nearest finally block. For try/except it
# is their finally. For finally it is parent finally. Otherwise -1
blockStack: seq[Block]

BlockKind = enum
# bkWhile
bkBlock
bkTry

Block = object
case kind: BlockKind
# of bkWhile:
# continues: seq[PNode]
of bkBlock:
discard
of bkTry:
discard
label: PNode
breaks: seq[PNode]

const
nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt,
Expand Down Expand Up @@ -228,11 +245,16 @@ proc newState(ctx: var Ctx, n, gotoOut: PNode): int =
assert(gotoOut.len == 0)
gotoOut.add(ctx.g.newIntLit(gotoOut.info, result))

proc copyStateSplitFlag(toNode, fromNode: PNode) =
if nfStateSplit in fromNode.flags:
toNode.flags.incl(nfStateSplit)

proc toStmtList(n: PNode): PNode =
result = n
if result.kind notin {nkStmtList, nkStmtListExpr}:
result = newNodeI(nkStmtList, n.info)
result.add(n)
copyStateSplitFlag(result, n)

proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
# Make sure `n` is a stmtlist, and ends with `gotoOut`
Expand All @@ -246,6 +268,7 @@ proc newTempVar(ctx: var Ctx, typ: PType): PSym =

proc hasYields(n: PNode): bool =
# TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt.
# Instead it should utilize nfStateSplit flag once yield-in-expr-lowering maintains it
case n.kind
of nkYieldStmt:
result = true
Expand All @@ -257,6 +280,7 @@ proc hasYields(n: PNode): bool =
result = true
break

# TODO: Remove this
proc transformBreaksAndContinuesInWhile(ctx: var Ctx, n: PNode, before, after: PNode): PNode =
result = n
case n.kind
Expand All @@ -276,6 +300,7 @@ proc transformBreaksAndContinuesInWhile(ctx: var Ctx, n: PNode, before, after: P
for i in 0..<n.len:
n[i] = ctx.transformBreaksAndContinuesInWhile(n[i], before, after)

# TODO: Remove this
proc transformBreaksInBlock(ctx: var Ctx, n: PNode, label, after: PNode): PNode =
result = n
case n.kind
Expand Down Expand Up @@ -344,8 +369,12 @@ proc collectExceptState(ctx: var Ctx, n: PNode): PNode {.inline.} =
ifBranch.add(c[^1])
ifStmt.add(ifBranch)

copyStateSplitFlag(ifBranch, c)
copyStateSplitFlag(ifStmt, ifBranch)

if ifStmt.len != 0:
result = newTree(nkStmtList, ctx.newNullifyCurExc(n.info), ifStmt)
copyStateSplitFlag(result, ifStmt)
else:
result = ctx.g.emptyNode

Expand Down Expand Up @@ -432,6 +461,7 @@ proc newNotCall(g: ModuleGraph; e: PNode): PNode =
result.typ = g.getSysType(e.info, tyBool)

proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
# TODO: Fix this to maintain nfStateSplit flag
result = n
case n.kind
of nkSkip:
Expand Down Expand Up @@ -854,6 +884,200 @@ proc transformReturnsInTry(ctx: var Ctx, n: PNode): PNode =
for i in 0..<n.len:
n[i] = ctx.transformReturnsInTry(n[i])

proc transformClosureIterBody(ctx: var Ctx, n: PNode): PNode =
# Go through control flow control nodes if they need to be split, and split them.
# Whenever a node is split, its last state must end up in the end of ctx.states

result = n
if nfStateSplit notin n.flags and n.kind notin { nkYieldStmt, nkStmtList, nkStmtListExpr }:
assert(not n.hasYields, "Unexpected yield in " & $n.kind)
return

case n.kind
of nkYieldStmt:
result = newNodeI(nkStmtList, n.info)
result.add(n)
let gotoNewState = newNodeI(nkGotoState, n.info)
let nextState = newNode(nkStmtList)
let s = ctx.newState(nextState, gotoNewState)
result.add(gotoNewState)

of nkStmtList, nkStmtListExpr:
let numStates = ctx.states.len
for i in 0 ..< n.len:
n[i] = transformClosureIterBody(ctx, n[i])
if numStates != ctx.states.len: # split happened
let s = newNodeI(nkStmtList, n[i].info)
for j in i + 1..<n.len:
s.add(n[j])
n.sons.setLen(i + 1)
ctx.states[^1][1].add(transformClosureIterBody(ctx, s))
break

of nkIfStmt, nkCaseStmt:
var nodesToAddExitTo: seq[PNode]
for i in 0 ..< n.len:
let numStates = ctx.states.len
n[i] = transformClosureIterBody(ctx, n[i])
if numStates != ctx.states.len:
nodesToAddExitTo.add(ctx.states[^1][1])

assert(nodesToAddExitTo.len != 0, "internal error")
let gotoNewState = newNodeI(nkGotoState, n.info)
let nextState = newNode(nkStmtList)
let s = ctx.newState(nextState, gotoNewState)
result = toStmtList(result)
result.add(gotoNewState)
for n in nodesToAddExitTo:
n.add(gotoNewState)

of nkElifBranch, nkElifExpr, nkOfBranch, nkElse:
result[^1] = transformClosureIterBody(ctx, result[^1])

of nkWhileStmt:
# while e:
# s
# ->
# BEGIN_STATE:
# if e:
# s
# goto BEGIN_STATE
# else:
# goto OUT
# OUT:
let numStates = ctx.states.len
n[1] = transformClosureIterBody(ctx, n[1])
if numStates != ctx.states.len:
let bodyEnd = ctx.states[^1][1]

result = newNodeI(nkGotoState, n.info)

let ifNode = newNodeI(nkIfStmt, n.info)
let elifBranch = newNodeI(nkElifBranch, n.info)
elifBranch.add(n[0])

discard ctx.newState(toStmtList(ifNode), result) # BEGIN_STATE

bodyEnd.add(result)

elifBranch.add(n[1])
ifNode.add(elifBranch)

let gotoNewState = newNodeI(nkGotoState, n.info)
let nextState = newNode(nkStmtList)
let s = ctx.newState(nextState, gotoNewState)
let elseBranch = newTree(nkElse, gotoNewState)
ifNode.add(elseBranch)

of nkBlockStmt:
ctx.blockStack.add(Block(kind: bkBlock, label: n[0]))
let numStates = ctx.states.len
result = transformClosureIterBody(ctx, n[1])
if numStates != ctx.states.len:
# If there are breaks in this block, create a new state after the block,
# and replace breaks with goto new state
if ctx.blockStack[^1].breaks.len != 0:
let gotoNewState = newNodeI(nkGotoState, n.info)
let nextState = newNode(nkStmtList)
let s = ctx.newState(nextState, gotoNewState)
for b in ctx.blockStack[^1].breaks:
b[0] = gotoNewState

discard ctx.blockStack.pop()

of nkBreakStmt:
let label = n[0]
result = n.toStmtList()
var targetBlk = -1
for i in countdown(ctx.blockStack.high, ctx.blockStack.low):
if ctx.blockStack[i].kind == bkBlock and (ctx.blockStack[i].label.sym == label.sym):
targetBlk = i
break

assert(targetBlk != -1)
ctx.blockStack[targetBlk].breaks.add(result)

of nkTryStmt, nkHiddenTryStmt:
# See explanation above about how this works
ctx.hasExceptions = true

result = newNodeI(nkGotoState, n.info)
var tryBody = toStmtList(n[0])
var exceptBody = ctx.collectExceptState(n)
var finallyBody = newTree(nkStmtList, getFinallyNode(ctx, n))
finallyBody = ctx.transformReturnsInTry(finallyBody)
finallyBody.add(ctx.newEndFinallyNode(finallyBody.info))


# The following index calculation is based on the knowledge how state
# indexes are assigned
let tryIdx = ctx.states.len
var exceptIdx, finallyIdx: int
if exceptBody.kind != nkEmpty:
exceptIdx = -(tryIdx + 1)
finallyIdx = tryIdx + 2
else:
exceptIdx = tryIdx + 1
finallyIdx = tryIdx + 1

let outToFinally = newNodeI(nkGotoState, finallyBody.info)
let gotoOut = newNodeI(nkGotoState, finallyBody.info)

tryBody = tryBody.addGotoOut(outToFinally)
if exceptBody.kind != nkEmpty:
exceptBody = exceptBody.addGotoOut(outToFinally)
finallyBody = finallyBody.addGotoOut(gotoOut)

block: # Create initial states.
let oldExcHandlingState = ctx.curExcHandlingState
ctx.curExcHandlingState = exceptIdx
let realTryIdx = ctx.newState(tryBody, result)
assert(realTryIdx == tryIdx)

if exceptBody.kind != nkEmpty:
ctx.curExcHandlingState = finallyIdx
let realExceptIdx = ctx.newState(exceptBody, nil)
assert(realExceptIdx == -exceptIdx)

ctx.curExcHandlingState = oldExcHandlingState
let realFinallyIdx = ctx.newState(finallyBody, outToFinally)
assert(realFinallyIdx == finallyIdx)

block: # Subdivide the states
let oldNearestFinally = ctx.nearestFinally
ctx.nearestFinally = finallyIdx

let oldExcHandlingState = ctx.curExcHandlingState

ctx.curExcHandlingState = exceptIdx

if ctx.transformReturnsInTry(tryBody) != tryBody:
internalError(ctx.g.config, "transformReturnsInTry != tryBody")
if ctx.transformClosureIterBody(tryBody) != tryBody:
internalError(ctx.g.config, "transformClosureIteratorBody != tryBody")

ctx.curExcHandlingState = finallyIdx
ctx.addElseToExcept(exceptBody)
if ctx.transformReturnsInTry(exceptBody) != exceptBody:
internalError(ctx.g.config, "transformReturnsInTry != exceptBody")
if ctx.transformClosureIterBody(exceptBody) != exceptBody:
internalError(ctx.g.config, "transformClosureIteratorBody != exceptBody")

ctx.curExcHandlingState = oldExcHandlingState
ctx.nearestFinally = oldNearestFinally
if ctx.transformClosureIterBody(finallyBody) != finallyBody:
internalError(ctx.g.config, "transformClosureIteratorBody != finallyBody")

block: # Create out state
let outState = newNode(nkStmtList)
discard ctx.newState(outState, gotoOut)

else:
if n.hasYields:
echo renderTree(n)
assert(not n.hasYields, "Node has unexpected yield: " & $n.kind)

# TODO: Delete this function
proc transformClosureIteratorBody(ctx: var Ctx, n: PNode, gotoOut: PNode): PNode =
result = n
case n.kind
Expand Down Expand Up @@ -1283,6 +1507,71 @@ proc deleteEmptyStates(ctx: var Ctx) =
else:
inc i

type
MarkYieldsContext = object
nodeStack: seq[PNode]

proc isTryWithFinally(n: PNode): bool =
n.kind == nkTryStmt and n[^1].kind == nkFinally

proc markYields*(ctx: var MarkYieldsContext, n: PNode) =
if n.kind in nkSkip: return

ctx.nodeStack.add(n)
for c in n:
markYields(ctx, c)
discard ctx.nodeStack.pop()

if n.kind == nkYieldStmt:
for i in countdown(ctx.nodeStack.high, ctx.nodeStack.low):
if nfStateSplit in ctx.nodeStack[i].flags:
break
ctx.nodeStack[i].flags.incl(nfStateSplit)

proc markTryNodesToSplit(ctx: var MarkYieldsContext, n: PNode) =
# Mark nkTry blocks which might not have a yield, but have a finally and a break/continue stmt
# corresponding to a block/while with yields
case n.kind
of nkSkip:
discard

of nkWhileStmt, nkBlockStmt:
ctx.nodeStack.add(n)
for c in n:
markTryNodesToSplit(ctx, c)
discard ctx.nodeStack.pop()

of nkTryStmt:
ctx.nodeStack.add(n)
for c in n:
if c.kind != nkFinally:
markTryNodesToSplit(ctx, c)
discard ctx.nodeStack.pop()
if n[^1].kind == nkFinally:
markTryNodesToSplit(ctx, n[^1])

of nkBreakStmt:
let label = n[0].sym
var targetBlk = -1
for i in countdown(ctx.nodeStack.high, ctx.nodeStack.low):
if ctx.nodeStack[i].kind == nkBlockStmt and label == ctx.nodeStack[i][0].sym:
targetBlk = i
break

assert(targetBlk != -1)
if nfStateSplit in ctx.nodeStack[targetBlk].flags:
for i in targetBlk .. ctx.nodeStack.high:
if ctx.nodeStack[i].isTryWithFinally():
ctx.nodeStack[i].flags.incl(nfStateSplit)
else:
for c in n:
markTryNodesToSplit(ctx, c)

proc markYields(n: PNode) =
var ctx: MarkYieldsContext
markYields(ctx, n)
markTryNodesToSplit(ctx, n)

proc transformClosureIterator*(g: ModuleGraph; fn: PSym, n: PNode): PNode =
var ctx: Ctx
ctx.g = g
Expand All @@ -1298,16 +1587,22 @@ proc transformClosureIterator*(g: ModuleGraph; fn: PSym, n: PNode): PNode =
var n = n.toStmtList

discard ctx.newState(n, nil)
let gotoOut = newTree(nkGotoState, g.newIntLit(n.info, -1))

var ns = false
n = ctx.lowerStmtListExprs(n, ns)

if n.hasYieldsInExpressions():
internalError(ctx.g.config, "yield in expr not lowered")

# TODO: This should go before lowerStmtListExprs,
# but first lowerStmtListExprs has to be fixed to maintain the nfStateSplitFlag.
# Just for performance considerations.
markYields(n)
if nfStateSplit notin n.flags:
return n

# Splitting transformation
discard ctx.transformClosureIteratorBody(n, gotoOut)
discard ctx.transformClosureIterBody(n)

# Optimize empty states away
ctx.deleteEmptyStates()
Expand Down
Loading

0 comments on commit 27b68b2

Please sign in to comment.