From a3465c7a4df84e64530666654d4b3b8c18782e2c Mon Sep 17 00:00:00 2001 From: Antonis Geralis <43617260+planetis-m@users.noreply.github.com> Date: Thu, 3 Dec 2020 21:32:18 +0200 Subject: [PATCH] add collect with infered init, refs #16078 fixes #14332 (#16089) * changelog * add testcase, fixes #14332 --- changelog.md | 3 ++ lib/pure/sugar.nim | 100 ++++++++++++++++++++++++++-------------- tests/stdlib/tsugar.nim | 22 +++++++-- 3 files changed, 88 insertions(+), 37 deletions(-) diff --git a/changelog.md b/changelog.md index 05958118bcb1f..bb49dc7f6ab03 100644 --- a/changelog.md +++ b/changelog.md @@ -22,6 +22,9 @@ literals remain in the "raw" string form so that client code can easily treat small and large numbers uniformly. +- Added an overload for the `collect` macro that inferes the container type based + on the syntax of the last expression. Works with std seqs, tables and sets. + - Added `randState` template that exposes the default random number generator. Useful for library authors. diff --git a/lib/pure/sugar.nim b/lib/pure/sugar.nim index 047104972233f..68672951518cc 100644 --- a/lib/pure/sugar.nim +++ b/lib/pure/sugar.nim @@ -58,7 +58,7 @@ macro `=>`*(p, b: untyped): untyped = runnableExamples: proc passTwoAndTwo(f: (int, int) -> int): int = f(2, 2) - + doAssert passTwoAndTwo((x, y) => x + y) == 4 type @@ -270,54 +270,79 @@ since (1, 1): underscoredCalls(result, calls, tmp) result.add tmp - -proc transLastStmt(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) {.since: (1, 1).} = +proc trans(n, res, bracketExpr: NimNode): (NimNode, NimNode, NimNode) {.since: (1, 1).} = # Looks for the last statement of the last statement, etc... case n.kind - of nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt: + of nnkIfExpr, nnkIfStmt, nnkTryStmt, nnkCaseStmt, nnkWhenStmt: result[0] = copyNimTree(n) result[1] = copyNimTree(n) result[2] = copyNimTree(n) - for i in ord(n.kind == nnkCaseStmt)..= 1: - (result[0][^1], result[1][^1], result[2][^1]) = transLastStmt(n[^1], res, bracketExpr) + (result[0][^1], result[1][^1], result[2][^1]) = trans(n[^1], + res, bracketExpr) of nnkTableConstr: result[1] = n[0][0] result[2] = n[0][1] + if bracketExpr.len == 0: + bracketExpr.add(ident"initTable") # don't import tables if bracketExpr.len == 1: - bracketExpr.add([newCall(bindSym"typeof", newEmptyNode()), newCall( - bindSym"typeof", newEmptyNode())]) + bracketExpr.add([newCall(bindSym"typeof", + newEmptyNode()), newCall(bindSym"typeof", newEmptyNode())]) template adder(res, k, v) = res[k] = v result[0] = getAst(adder(res, n[0][0], n[0][1])) of nnkCurly: result[2] = n[0] + if bracketExpr.len == 0: + bracketExpr.add(ident"initHashSet") if bracketExpr.len == 1: bracketExpr.add(newCall(bindSym"typeof", newEmptyNode())) template adder(res, v) = res.incl(v) result[0] = getAst(adder(res, n[0])) else: result[2] = n + if bracketExpr.len == 0: + bracketExpr.add(bindSym"newSeq") if bracketExpr.len == 1: bracketExpr.add(newCall(bindSym"typeof", newEmptyNode())) template adder(res, v) = res.add(v) result[0] = getAst(adder(res, n)) +proc collectImpl(init, body: NimNode): NimNode {.since: (1, 1).} = + let res = genSym(nskVar, "collectResult") + var bracketExpr: NimNode + if init != nil: + expectKind init, {nnkCall, nnkIdent, nnkSym} + bracketExpr = newTree(nnkBracketExpr, + if init.kind == nnkCall: freshIdentNodes(init[0]) else: freshIdentNodes(init)) + else: + bracketExpr = newTree(nnkBracketExpr) + let (resBody, keyType, valueType) = trans(body, res, bracketExpr) + if bracketExpr.len == 3: + bracketExpr[1][1] = keyType + bracketExpr[2][1] = valueType + else: + bracketExpr[1][1] = valueType + let call = newTree(nnkCall, bracketExpr) + if init != nil and init.kind == nnkCall: + for i in 1 ..< init.len: + call.add init[i] + result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res) + macro collect*(init, body: untyped): untyped {.since: (1, 1).} = - ## Comprehension for seq/set/table collections. ``init`` is - ## the init call, and so custom collections are supported. + ## Comprehension for seqs/sets/tables. ## - ## The last statement of ``body`` has special syntax that specifies - ## the collection's add operation. Use ``{e}`` for set's ``incl``, - ## ``{k: v}`` for table's ``[]=`` and ``e`` for seq's ``add``. - ## - ## The ``init`` proc can be called with any number of arguments, - ## i.e. ``initTable(initialSize)``. + ## The last expression of `body` has special syntax that specifies + ## the collection's add operation. Use `{e}` for set's `incl`, + ## `{k: v}` for table's `[]=` and `e` for seq's `add`. + # analyse the body, find the deepest expression 'it' and replace it via + # 'result.add it' runnableExamples: import sets, tables let data = @["bird", "word"] @@ -343,20 +368,27 @@ macro collect*(init, body: untyped): untyped {.since: (1, 1).} = for i, d in data.pairs: {i: d} assert z == {0: "bird", 1: "word"}.toTable - # analyse the body, find the deepest expression 'it' and replace it via - # 'result.add it' - let res = genSym(nskVar, "collectResult") - expectKind init, {nnkCall, nnkIdent, nnkSym} - let bracketExpr = newTree(nnkBracketExpr, - if init.kind == nnkCall: init[0] else: init) - let (resBody, keyType, valueType) = transLastStmt(body, res, bracketExpr) - if bracketExpr.len == 3: - bracketExpr[1][1] = keyType - bracketExpr[2][1] = valueType - else: - bracketExpr[1][1] = valueType - let call = newTree(nnkCall, bracketExpr) - if init.kind == nnkCall: - for i in 1 ..< init.len: - call.add init[i] - result = newTree(nnkStmtListExpr, newVarStmt(res, call), resBody, res) + result = collectImpl(init, body) + +macro collect*(body: untyped): untyped {.since: (1, 5).} = + ## Same as `collect` but without an `init` parameter. + runnableExamples: + import sets, tables + # Seq: + let data = @["bird", "word"] + let k = collect: + for i, d in data.pairs: + if i mod 2 == 0: d + + assert k == @["bird"] + ## HashSet: + let n = collect: + for d in data.items: {d} + + assert n == data.toHashSet + ## Table: + let m = collect: + for i, d in data.pairs: {i: d} + + assert m == {0: "bird", 1: "word"}.toTable + result = collectImpl(nil, body) \ No newline at end of file diff --git a/tests/stdlib/tsugar.nim b/tests/stdlib/tsugar.nim index a6fd70f3418ad..cca1fe75a7cca 100644 --- a/tests/stdlib/tsugar.nim +++ b/tests/stdlib/tsugar.nim @@ -40,7 +40,8 @@ import random const b = @[0, 1, 2] let c = b.dup shuffle() -doAssert c.len == 3 +doAssert b[0] == 0 +doAssert b[1] == 1 #test collect import sets, tables @@ -83,12 +84,27 @@ let z = collect(newSeq): else: d assert z == @["word", "word"] - proc tforum = let ans = collect(newSeq): for y in 0..10: if y mod 5 == 2: for x in 0..y: x - tforum() + +block: + let x = collect: + for d in data.items: + when d is int: "word" + else: d + assert x == @["bird", "word"] +assert collect(for (i, d) in pairs(data): (i, d)) == @[(0, "bird"), (1, "word")] +assert collect(for d in data.items: (try: parseInt(d) except: 0)) == @[0, 0] +assert collect(for (i, d) in pairs(data): {i: d}) == {1: "word", + 0: "bird"}.toTable +assert collect(for d in data.items: {d}) == data.toHashSet + +# bug #14332 +template foo = + discard collect(newSeq, for i in 1..3: i) +foo() \ No newline at end of file