Skip to content

Commit

Permalink
fix: compile memo with proper nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Oct 14, 2023
1 parent 4484007 commit a88de67
Show file tree
Hide file tree
Showing 26 changed files with 209 additions and 70 deletions.
7 changes: 5 additions & 2 deletions bolt/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,11 @@ def memo(
f"{storage} = _bolt_runtime.memo.registry[__file__][{acc.make_ref(node)}, {file_index}]"
)

path = f"_bolt_memo_invocation_path_{node.persistent_id.hex}"
acc.statement(f"{path} = _bolt_runtime.get_nested_location()")

invocation = f"_bolt_memo_invocation_{node.persistent_id.hex}"
acc.statement(f"{invocation} = {storage}[({' '.join(keys)})]")
acc.statement(f"{invocation} = {storage}[({path}, {' '.join(keys)})]")

acc.statement(f"if {invocation}.cached:")
with acc.block():
Expand All @@ -597,7 +600,7 @@ def memo(
acc.statement("else:")
with acc.block():
acc.statement(
f"with _bolt_runtime.memo.record(_bolt_runtime, {invocation}, __name__):"
f"with _bolt_runtime.memo.record(_bolt_runtime, {invocation}, {path}, __name__):"
)
with acc.block():
yield from visit_body(cast(AstRoot, body), acc)
Expand Down
16 changes: 6 additions & 10 deletions bolt/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
)
from beet.core.utils import log_time, remove_path
from mecha import AstChildren, AstRoot, DiagnosticCollection, Mecha, rule
from mecha.contrib.nesting import InplaceNestingPredicate

from .ast import AstMemo, AstMemoResult
from .emit import CommandEmitter
Expand Down Expand Up @@ -302,13 +301,12 @@ class MemoHandler:
mc: Mecha
registry: MemoRegistry
generate: Optional[Generator] = None
inplace_nesting_predicate: Optional[InplaceNestingPredicate] = None

def __post_init__(self):
self.mc.serialize.extend(serialize_memo_result)

def restore(self, emit: CommandEmitter, invocation: MemoInvocation):
if not self.generate or not self.inplace_nesting_predicate:
if self.generate is None:
return

invocation.epoch = self.registry.epoch_counter
Expand All @@ -324,9 +322,10 @@ def record(
self,
emit: CommandEmitter,
invocation: MemoInvocation,
invocation_path: str,
name: Optional[str] = None,
):
if not self.generate or not self.inplace_nesting_predicate or not name:
if self.generate is None or name is None:
yield
return

Expand All @@ -346,16 +345,14 @@ def record(
previous_queue = database.queue
previous_step = database.step
previous_current = database.current
previous_callback = self.inplace_nesting_predicate.callback
try:
database.session = set()
database.queue = []
self.inplace_nesting_predicate.callback = (
lambda target: previous_callback(target)
or target is previous_current
)
function = self.mc.compile(
root,
filename=compilation_unit.filename,
resource_location=invocation_path,
within=draft.data,
report=diagnostics,
initial_step=database.step + 1,
)
Expand All @@ -365,7 +362,6 @@ def record(
database.queue = previous_queue
database.step = previous_step
database.current = previous_current
self.inplace_nesting_predicate.callback = previous_callback

if output := function.text.rstrip():
emit.commands.append(AstMemoResult(serialized=output))
Expand Down
10 changes: 5 additions & 5 deletions bolt/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
rule,
)
from mecha.contrib.nested_location import NestedLocationResolver
from mecha.contrib.nesting import InplaceNestingPredicate
from mecha.contrib.relative_location import resolve_relative_location
from pathspec import PathSpec
from tokenstream import set_location
Expand Down Expand Up @@ -84,9 +83,11 @@ def __init__(self, ctx: Union[Context, Mecha]):
"generate_tree",
lambda *args, **kwargs: generate_tree(
(
root := kwargs.pop("root")
if "root" in kwargs
else self.modules.current_path
root := (
kwargs.pop("root")
if "root" in kwargs
else self.get_nested_location()
)
),
*args,
name=(
Expand All @@ -103,7 +104,6 @@ def __init__(self, ctx: Union[Context, Mecha]):
mc,
registry=ctx.inject(MemoRegistry),
generate=ctx.generate,
inplace_nesting_predicate=ctx.inject(InplaceNestingPredicate),
)

else:
Expand Down
6 changes: 6 additions & 0 deletions examples/bolt_memo2/beet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
require:
- bolt
data_pack:
load: "src"
pipeline:
- mecha
4 changes: 4 additions & 0 deletions examples/bolt_memo2/src/data/demo/functions/foo.mcfunction
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
say ok
memo:
append function __name__:
say wat
6 changes: 6 additions & 0 deletions examples/bolt_memo3/beet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
require:
- bolt
data_pack:
load: "src"
pipeline:
- mecha
7 changes: 7 additions & 0 deletions examples/bolt_memo3/src/data/demo/functions/foo.mcfunction
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
n = 5
execute function ~/{n}:
say 1
memo n:
say 2
append function ~/:
say 3
6 changes: 6 additions & 0 deletions examples/bolt_memo4/beet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
require:
- bolt
data_pack:
load: "src"
pipeline:
- mecha
13 changes: 13 additions & 0 deletions examples/bolt_memo4/src/data/demo/functions/foo.mcfunction
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
def stuff():
memo:
append function __name__:
say something

say before
stuff()
stuff()
say after

function ./bar:
say bop
stuff()
16 changes: 8 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ include = ["bolt/py.typed"]

[tool.poetry.dependencies]
python = "^3.10"
beet = ">=0.95.4"
mecha = ">=0.78.2"
beet = ">=0.96.1"
mecha = ">=0.79.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
Expand Down
12 changes: 7 additions & 5 deletions tests/snapshots/bolt__parse_313__1.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_bolt_lineno = [1, 20], [1, 3]
_bolt_lineno = [1, 21], [1, 3]
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = None
_bolt_helper_children = _bolt_runtime.helpers['children']
_bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 = None
Expand All @@ -11,11 +11,12 @@ with _bolt_runtime.scope() as _bolt_var6:
_bolt_var1 = _bolt_var1 + _bolt_var2
if _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d is None:
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.memo.registry[__file__][_bolt_refs[0], 0]
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_var0, _bolt_var1,)]
_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, _bolt_var0, _bolt_var1,)]
if _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, _bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, __name__):
pass
_bolt_var3 = list
_bolt_var3 = _bolt_var3()
Expand All @@ -25,11 +26,12 @@ with _bolt_runtime.scope() as _bolt_var6:
_bolt_var3 = _bolt_var3(_bolt_var4)
if _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 is None:
_bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 = _bolt_runtime.memo.registry[__file__][_bolt_refs[1], 0]
_bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257 = _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257[(_bolt_var3,)]
_bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257 = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257 = _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257[(_bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257, _bolt_var3,)]
if _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257, _bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257, __name__):
pass
_bolt_var7 = _bolt_helper_replace(_bolt_refs[2], commands=_bolt_helper_children(_bolt_var6))
---
Expand Down
5 changes: 3 additions & 2 deletions tests/snapshots/bolt__parse_314__1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ with _bolt_runtime.scope() as _bolt_var3:
_bolt_var1 = _bolt_var1 + _bolt_var2
if _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d is None:
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.memo.registry[__file__][_bolt_refs[0], 0]
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_var0, _bolt_var1,)]
_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, _bolt_var0, _bolt_var1,)]
if _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, _bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, __name__):
pass
_bolt_var4 = _bolt_helper_replace(_bolt_refs[1], commands=_bolt_helper_children(_bolt_var3))
---
Expand Down
7 changes: 4 additions & 3 deletions tests/snapshots/bolt__parse_315__1.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_bolt_lineno = [1, 18, 21], [1, 2, 3]
_bolt_lineno = [1, 19, 22], [1, 2, 3]
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = None
_bolt_helper_interpolate_message = _bolt_runtime.helpers['interpolate_message']
_bolt_helper_children = _bolt_runtime.helpers['children']
Expand All @@ -10,11 +10,12 @@ with _bolt_runtime.scope() as _bolt_var4:
bar = _bolt_var1
if _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d is None:
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.memo.registry[__file__][_bolt_refs[0], 0]
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(foo, bar,)]
_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, foo, bar,)]
if _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, _bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, __name__):
_bolt_var2 = foo
_bolt_var2 = _bolt_helper_interpolate_message(_bolt_var2, _bolt_refs[1])
_bolt_runtime.commands.append(_bolt_helper_replace(_bolt_refs[2], arguments=_bolt_helper_children([_bolt_var2])))
Expand Down
10 changes: 6 additions & 4 deletions tests/snapshots/bolt__parse_316__1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ _bolt_helper_replace = _bolt_runtime.helpers['replace']
with _bolt_runtime.scope() as _bolt_var0:
if _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d is None:
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.memo.registry[__file__][_bolt_refs[0], 0]
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[()]
_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, )]
if _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, _bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, __name__):
pass
if _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 is None:
_bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 = _bolt_runtime.memo.registry[__file__][_bolt_refs[1], 1]
_bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257 = _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257[()]
_bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257 = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257 = _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257[(_bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257, )]
if _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257, _bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257, __name__):
pass
_bolt_var1 = _bolt_helper_replace(_bolt_refs[2], commands=_bolt_helper_children(_bolt_var0))
---
Expand Down
17 changes: 10 additions & 7 deletions tests/snapshots/bolt__parse_317__1.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_bolt_lineno = [1, 18, 22, 33, 37, 48], [1, 2, 3, 4, 5, 6]
_bolt_lineno = [1, 19, 23, 35, 39, 51], [1, 2, 3, 4, 5, 6]
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = None
_bolt_helper_children = _bolt_runtime.helpers['children']
_bolt_helper_get_rebind = _bolt_runtime.helpers['get_rebind']
Expand All @@ -10,11 +10,12 @@ with _bolt_runtime.scope() as _bolt_var9:
foo = _bolt_var0
if _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d is None:
_bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.memo.registry[__file__][_bolt_refs[0], 0]
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(foo,)]
_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d = _bolt_memo_storage_bdd640fb06671ad11c80317fa3b1799d[(_bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, foo,)]
if _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bdd640fb06671ad11c80317fa3b1799d, _bolt_memo_invocation_path_bdd640fb06671ad11c80317fa3b1799d, __name__):
_bolt_var1 = print
_bolt_var2 = foo
_bolt_var1 = _bolt_var1(_bolt_var2)
Expand All @@ -25,11 +26,12 @@ with _bolt_runtime.scope() as _bolt_var9:
foo = _bolt_rebind(foo)
if _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 is None:
_bolt_memo_storage_23b8c1e9392456de3eb13b9046685257 = _bolt_runtime.memo.registry[__file__][_bolt_refs[1], 0]
_bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257 = _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257[(foo,)]
_bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257 = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257 = _bolt_memo_storage_23b8c1e9392456de3eb13b9046685257[(_bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257, foo,)]
if _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_23b8c1e9392456de3eb13b9046685257, _bolt_memo_invocation_path_23b8c1e9392456de3eb13b9046685257, __name__):
_bolt_var4 = print
_bolt_var5 = foo
_bolt_var4 = _bolt_var4(_bolt_var5)
Expand All @@ -40,11 +42,12 @@ with _bolt_runtime.scope() as _bolt_var9:
foo = _bolt_rebind(foo)
if _bolt_memo_storage_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9 is None:
_bolt_memo_storage_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9 = _bolt_runtime.memo.registry[__file__][_bolt_refs[2], 1]
_bolt_memo_invocation_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9 = _bolt_memo_storage_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9[(foo,)]
_bolt_memo_invocation_path_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9 = _bolt_runtime.get_nested_location()
_bolt_memo_invocation_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9 = _bolt_memo_storage_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9[(_bolt_memo_invocation_path_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9, foo,)]
if _bolt_memo_invocation_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9.cached:
_bolt_runtime.memo.restore(_bolt_runtime, _bolt_memo_invocation_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9)
else:
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9, __name__):
with _bolt_runtime.memo.record(_bolt_runtime, _bolt_memo_invocation_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9, _bolt_memo_invocation_path_bd9c66b3ad3c2d6d1a3d1fa7bc8960a9, __name__):
_bolt_var7 = print
_bolt_var8 = foo
_bolt_var7 = _bolt_var7(_bolt_var8)
Expand Down
Loading

0 comments on commit a88de67

Please sign in to comment.