diff --git a/bolt/__init__.py b/bolt/__init__.py index 74ce947..046a18c 100644 --- a/bolt/__init__.py +++ b/bolt/__init__.py @@ -10,3 +10,4 @@ from .parse import * from .plugin import * from .runtime import * +from .semantics import * diff --git a/bolt/parse.py b/bolt/parse.py index aa0afac..f014d11 100644 --- a/bolt/parse.py +++ b/bolt/parse.py @@ -1,29 +1,23 @@ __all__ = [ "get_bolt_parsers", - "get_stream_builtins", - "get_stream_identifiers", - "get_stream_pending_identifiers", - "get_stream_identifiers_storage", - "get_stream_pending_identifiers_storage", - "get_stream_deferred_locals", - "get_stream_branch_scope", + "get_stream_lexical_scope", "get_stream_macro_scope", "get_stream_pending_macros", "ToplevelHandler", "create_bolt_root_parser", + "create_bolt_command_parser", "UndefinedIdentifier", "UndefinedIdentifierErrorHandler", "BranchScopeManager", - "check_final_expression", "FinalExpressionParser", "InterpolationParser", "DisableInterpolationParser", - "parse_statement", + "ForkLexicalScopeParser", + "AssignmentStatementParser", "parse_decorator", "AssignmentTargetParser", "IfElseLoweringParser", "BreakContinueConstraint", - "FlushPendingIdentifiersParser", "parse_deferred_root", "parse_function_signature", "parse_proc_macro_signature", @@ -34,7 +28,6 @@ "parse_class_name", "parse_class_bases", "parse_class_root", - "ClassBodyScoping", "parse_del_target", "parse_identifier", "TrailingCommaParser", @@ -46,7 +39,7 @@ "GlobalNonlocalHandler", "StatementSubcommandHandler", "DeferredRootBacktracker", - "FunctionConstraint", + "LexicalScopeConstraint", "RootScopeHandler", "BinaryParser", "UnaryParser", @@ -63,7 +56,7 @@ import re from dataclasses import dataclass, field, replace -from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, cast +from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, Type, cast from beet.core.utils import extra_field from mecha import ( @@ -161,7 +154,17 @@ STRING_PATTERN, TRUE_PATTERN, ) -from .utils import internal, suggest_typo +from .semantics import ( + ClassScope, + FunctionScope, + GlobalScope, + LexicalScope, + MacroScope, + ProcMacroScope, + UnboundLocalIdentifier, + UndefinedIdentifier, +) +from .utils import internal IMPORT_REGEX = re.compile(rf"^{MODULE_PATTERN}$") @@ -189,11 +192,7 @@ def get_bolt_parsers( macro_handler=macro_handler, ), "nested_root": create_bolt_root_parser(parsers["nested_root"], macro_handler), - "command": StatementSubcommandHandler( - UndefinedIdentifierErrorHandler( - ImportStatementHandler(GlobalNonlocalHandler(macro_handler), modules) - ) - ), + "command": create_bolt_command_parser(macro_handler, modules), "command:argument:bolt:if_block": delegate("bolt:if_block"), "command:argument:bolt:elif_condition": delegate("bolt:elif_condition"), "command:argument:bolt:elif_block": delegate("bolt:elif_block"), @@ -222,28 +221,52 @@ def get_bolt_parsers( ################################################################################ "bolt:if_block": BranchScopeManager( parser=delegate("nested_root"), - update_before=True, - ), - "bolt:elif_condition": BranchScopeManager( - parser=delegate("bolt:expression"), - mask=True, - update_after=True, + init=True, + fork=True, ), + "bolt:elif_condition": BranchScopeManager(delegate("bolt:expression")), "bolt:elif_block": BranchScopeManager( parser=delegate("nested_root"), - mask=True, + fork=True, ), "bolt:else_block": BranchScopeManager( parser=delegate("nested_root"), - mask=True, + fork=True, + ), + "bolt:statement": AlternativeParser( + [ + ForkLexicalScopeParser( + FlushPendingBindingsParser( + FinalExpressionParser( + AssignmentStatementParser( + delegate("bolt:augmented_assignment_target") + ) + ), + after=True, + ) + ), + ForkLexicalScopeParser( + FlushPendingBindingsParser( + FinalExpressionParser( + AssignmentStatementParser( + delegate("bolt:assignment_target"), + binding_only=True, + ) + ), + after=True, + ) + ), + FinalExpressionParser( + AssignmentStatementParser(delegate("bolt:expression")) + ), + FinalExpressionParser(delegate("bolt:decorator")), + ] ), - "bolt:statement": parse_statement, "bolt:decorator": parse_decorator, - "bolt:assignment_target": AssignmentTargetParser( - allow_undefined_identifiers=True, - allow_multiple=True, + "bolt:assignment_target": AssignmentTargetParser(), + "bolt:augmented_assignment_target": AssignmentTargetParser( + require_local_binding=True, require_single=True ), - "bolt:augmented_assignment_target": AssignmentTargetParser(), "bolt:deferred_root": parse_deferred_root, "bolt:function_signature": parse_function_signature, "bolt_macro_literal": delegate("bolt:macro_literal"), @@ -265,15 +288,16 @@ def get_bolt_parsers( "bolt:proc_macro": ProcMacroParser(modules), "bolt:class_name": parse_class_name, "bolt:class_bases": parse_class_bases, - "bolt:class_root": parse_class_root, + "bolt:class_root": FlushPendingBindingsParser(parse_class_root, after=True), "bolt:del_target": parse_del_target, "bolt:interpolation": BuiltinCallRestriction( - PrimaryParser(delegate("bolt:identifier"), truncate=True) + PrimaryParser(delegate("bolt:identifier"), truncate=True), + builtins=modules.builtins, ), "bolt:identifier": parse_identifier, "bolt:with_expression": TrailingCommaParser(delegate("bolt:expression")), "bolt:with_target": TrailingCommaParser( - AssignmentTargetParser(allow_undefined_identifiers=True) + AssignmentTargetParser(require_single=True) ), "bolt:import": AlternativeParser( [ @@ -489,39 +513,9 @@ def get_bolt_parsers( } -def get_stream_builtins(stream: TokenStream) -> Set[str]: - """Return the set of builtin identifiers currently associated with the token stream.""" - return stream.data.setdefault("builtins", set()) - - -def get_stream_identifiers(stream: TokenStream) -> Set[str]: - """Return the set of accessible identifiers currently associated with the token stream.""" - return stream.data.setdefault("identifiers", set()) - - -def get_stream_pending_identifiers(stream: TokenStream) -> Set[str]: - """Return the set of pending identifiers currently associated with the token stream.""" - return stream.data.setdefault("pending_identifiers", set()) - - -def get_stream_identifiers_storage(stream: TokenStream) -> Dict[str, str]: - """Return a dict that associates storage type to identifiers.""" - return stream.data.setdefault("identifiers_storage", {}) - - -def get_stream_pending_identifiers_storage(stream: TokenStream) -> Dict[str, str]: - """Return a dict that associates pending storage type to identifiers.""" - return stream.data.setdefault("pending_identifiers_storage", {}) - - -def get_stream_deferred_locals(stream: TokenStream) -> Set[str]: - """Return a set of identifiers that will be available in the next deferred root.""" - return stream.data.setdefault("deferred_locals", set()) - - -def get_stream_branch_scope(stream: TokenStream) -> Set[str]: - """Return the identifiers available inside the branch.""" - return stream.data.setdefault("branch_scope", set()) +def get_stream_lexical_scope(stream: TokenStream) -> LexicalScope: + """Return the lexical scope currently associated with the token stream.""" + return stream.data.setdefault("lexical_scope", LexicalScope()) def get_stream_macro_scope(stream: TokenStream) -> Dict[str, AstMacro]: @@ -534,11 +528,6 @@ def get_stream_pending_macros(stream: TokenStream) -> List[AstMacro]: return stream.data.setdefault("pending_macros", []) -def get_stream_class_scope(stream: TokenStream) -> Optional[Set[str]]: - """Return the set of outer identifiers available inside the class.""" - return stream.data.get("class_scope") - - @dataclass class ToplevelHandler: """Handle toplevel root node.""" @@ -550,13 +539,17 @@ class ToplevelHandler: def __call__(self, stream: TokenStream) -> Any: current = self.modules.database.current + resource_location = self.modules.database[current].resource_location - with self.modules.parse_push(current), stream.provide( - resource_location=self.modules.database[current].resource_location, - builtins=self.modules.builtins, + global_scope = GlobalScope( identifiers=set(self.modules.globals) | self.modules.builtins | {"__name__"}, + ) + + with self.modules.parse_push(current), stream.provide( + resource_location=resource_location, + lexical_scope=global_scope.push(LexicalScope), ): node = self.parser(stream) @@ -569,7 +562,8 @@ def __call__(self, stream: TokenStream) -> Any: def create_bolt_root_parser(parser: Parser, macro_handler: "MacroHandler"): - """Return parser for the root node for bolt.""" + """Compose root parsers.""" + parser = FlushPendingBindingsParser(parser, before=True) parser = IfElseLoweringParser(parser) parser = BreakContinueConstraint( parser, @@ -578,8 +572,9 @@ def create_bolt_root_parser(parser: Parser, macro_handler: "MacroHandler"): ("for", "target", "in", "iterable", "body"), }, ) - parser = FunctionConstraint( + parser = LexicalScopeConstraint( parser, + type=FunctionScope, command_identifiers={ "return", "return:value", @@ -590,8 +585,6 @@ def create_bolt_root_parser(parser: Parser, macro_handler: "MacroHandler"): "nonlocal:subcommand", }, ) - parser = FlushPendingIdentifiersParser(parser) - parser = ClassBodyScoping(parser) parser = DeferredRootBacktracker(parser, macro_handler=macro_handler) parser = DecoratorResolver(parser) parser = ProcMacroExpansion(parser) @@ -599,22 +592,13 @@ def create_bolt_root_parser(parser: Parser, macro_handler: "MacroHandler"): return parser -class UndefinedIdentifier(InvalidSyntax): - """Raised when an identifier is not defined.""" - - identifier: str - possible_identifiers: Iterable[str] - - def __init__(self, identifier: str, identifiers: Iterable[str]): - super().__init__(identifier, identifiers) - self.identifier = identifier - self.identifiers = identifiers - - def __str__(self) -> str: - msg = f'Identifier "{self.identifier}" is not defined.' - if suggestion := suggest_typo(self.identifier, self.identifiers): - msg += f" Did you mean {suggestion}?" - return msg +def create_bolt_command_parser(parser: Parser, modules: ModuleManager): + """Compose command parsers.""" + parser = GlobalNonlocalHandler(parser) + parser = ImportStatementHandler(parser, modules) + parser = UndefinedIdentifierErrorHandler(parser) + parser = StatementSubcommandHandler(parser) + return parser @dataclass @@ -629,7 +613,9 @@ def __call__(self, stream: TokenStream) -> Any: except UndefinedIdentifier: raise except InvalidSyntax as exc: - for alt in exc.alternatives.get(UndefinedIdentifier, []): + alts = list(exc.alternatives.get(UnboundLocalIdentifier, [])) + alts += exc.alternatives.get(UndefinedIdentifier, []) + for alt in alts: if alt.end_location.pos + 1 >= exc.location.pos: # kind of a cheat alt.notes.append(str(exc)) raise alt from None @@ -641,55 +627,26 @@ class BranchScopeManager: """Parser that manages accessible identifiers for conditional branches.""" parser: Parser - mask: bool = False - update_before: bool = False - update_after: bool = False + init: bool = False + fork: bool = False def __call__(self, stream: TokenStream) -> Any: - identifiers = get_stream_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) - branch_scope = get_stream_branch_scope(stream) - - mask = branch_scope and identifiers - branch_scope - mask_storage = { - identifier: identifiers_storage[identifier] - for identifier in set(identifiers_storage) & mask - } - - if self.mask: - identifiers -= mask - for identifier in mask_storage: - del identifiers_storage[identifier] - - if self.update_before: - branch_scope = identifiers.copy() - - try: - return self.parser(stream) - finally: - if self.update_after: - branch_scope = identifiers.copy() - - if self.mask: - identifiers.update(mask) - identifiers_storage.update(mask_storage) + lexical_scope = get_stream_lexical_scope(stream) - stream.data["branch_scope"] = branch_scope + if self.init or not lexical_scope.next_branch: + lexical_scope.next_branch = lexical_scope.fork() + branch_scope = lexical_scope.next_branch + if self.fork: + branch_scope = branch_scope.fork() -def check_final_expression(stream: TokenStream): - current_index = stream.index - - if consume_line_continuation(stream): - exc = InvalidSyntax("Invalid indent following final expression.") - raise set_location(exc, stream.get()) + with stream.provide(lexical_scope=branch_scope): + node = self.parser(stream) - next_token = stream.get() - if next_token and not next_token.match("newline", "eof"): - exc = InvalidSyntax("Trailing input following final expression.") - raise set_location(exc, next_token) + if self.fork: + lexical_scope.reconcile(branch_scope) - stream.index = current_index + return node @dataclass @@ -755,80 +712,63 @@ def __call__(self, stream: TokenStream) -> Any: return self.parser(stream) -def parse_statement(stream: TokenStream) -> Any: - """Parse statement.""" - identifiers = get_stream_identifiers(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) - pending_identifiers_storage = get_stream_pending_identifiers_storage(stream) +@dataclass +class ForkLexicalScopeParser: + """Parser forking lexical scope.""" - for parser, alternative in stream.choose( - "bolt:augmented_assignment_target", - "bolt:assignment_target", - "bolt:expression", - "bolt:decorator", - ): - with alternative: - pending_identifiers.clear() - pending_identifiers_storage.clear() - node = delegate(parser, stream) + parser: Parser - pattern = r"=(?!=)|\+=|-=|\*=|//=|/=|%=|&=|\|=|\^=|<<=|>>=|\*\*=" + def __call__(self, stream: TokenStream) -> Any: + lexical_scope = get_stream_lexical_scope(stream) + temporary_scope = lexical_scope.fork() + with stream.provide(lexical_scope=temporary_scope): + node = self.parser(stream) + lexical_scope.reconcile(temporary_scope) + return node - if isinstance(node, AstTarget): - if parser == "bolt:assignment_target": - pattern = r"=(?!=)" - with stream.syntax(assignment=pattern): - op = stream.expect("assignment") - expression = delegate("bolt:expression", stream) +@dataclass +class AssignmentStatementParser: + """Parser for assignment statements.""" - identifiers |= pending_identifiers - pending_identifiers.clear() - identifiers_storage.update(pending_identifiers_storage) - pending_identifiers_storage.clear() + parser: Parser + binding_only: bool = False - node = AstAssignment(operator=op.value, target=node, value=expression) - node = set_location(node, node.target, node.value) + def __call__(self, stream: TokenStream) -> Any: + node = self.parser(stream) - elif isinstance(node, AstAttribute): - with stream.syntax(assignment=pattern): - op = stream.get("assignment") + assignment_pattern = r"=(?!=)|\+=|-=|\*=|//=|/=|%=|&=|\|=|\^=|<<=|>>=|\*\*=" - if op: - expression = delegate("bolt:expression", stream) - target = AstTargetAttribute( - name=node.name, - value=node.value, - ) - node = AstAssignment( - operator=op.value, - target=set_location(target, node), - value=expression, - ) - node = set_location(node, node.target, node.value) + if isinstance(node, AstTarget): + if self.binding_only: + assignment_pattern = r"=(?!=)" + with stream.syntax(assignment=assignment_pattern): + op = stream.expect("assignment") - elif isinstance(node, AstLookup): - with stream.syntax(assignment=pattern): - op = stream.get("assignment") + expression = delegate("bolt:expression", stream) - if op: - expression = delegate("bolt:expression", stream) - target = AstTargetItem( - value=node.value, - arguments=node.arguments, - ) - node = AstAssignment( - operator=op.value, - target=set_location(target, node), - value=expression, - ) - node = set_location(node, node.target, node.value) + node = AstAssignment(operator=op.value, target=node, value=expression) + node = set_location(node, node.target, node.value) - # Make sure that the statement is not followed by anything. - check_final_expression(stream) + elif isinstance(node, (AstAttribute, AstLookup)): + with stream.syntax(assignment=assignment_pattern): + op = stream.get("assignment") - return node + if op: + expression = delegate("bolt:expression", stream) + target = ( + AstTargetAttribute(name=node.name, value=node.value) + if isinstance(node, AstAttribute) + else AstTargetItem(value=node.value, arguments=node.arguments) + ) + node = AstAssignment( + operator=op.value, + target=set_location(target, node), + value=expression, + ) + node = set_location(node, node.target, node.value) + + return node def parse_decorator(stream: TokenStream) -> Any: @@ -889,40 +829,30 @@ def __call__(self, stream: TokenStream) -> AstRoot: class AssignmentTargetParser: """Parser for assignment targets.""" - allow_undefined_identifiers: bool = False - allow_multiple: bool = False + require_local_binding: bool = False + require_single: bool = False def __call__(self, stream: TokenStream) -> AstTarget: - identifiers = get_stream_identifiers(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) - pending_identifiers_storage = get_stream_pending_identifiers_storage(stream) + lexical_scope = get_stream_lexical_scope(stream) nodes: List[AstTarget] = [] with stream.syntax(identifier=IDENTIFIER_PATTERN, comma=r","): while True: token = stream.expect("identifier") - is_defined = token.value in identifiers - with_storage = token.value in identifiers_storage - rebind = is_defined and with_storage - - if self.allow_undefined_identifiers: - pending_identifiers.add(token.value) - if not with_storage: - pending_identifiers_storage[token.value] = "local" - elif not rebind: - exc = UndefinedIdentifier(token.value, list(identifiers_storage)) - if is_defined and not with_storage: - exc.notes.append( - f"Use 'global {token.value}' or 'nonlocal {token.value}' to mutate the variable defined in outer scope." - ) - raise set_location(exc, token) + rebind = lexical_scope.has_binding(token.value) target = AstTargetIdentifier(value=token.value, rebind=rebind) - nodes.append(set_location(target, token)) + target = set_location(target, token) - if not self.allow_multiple or not stream.get("comma"): + if self.require_local_binding: + lexical_scope.reference_binding(token.value, target) + + lexical_scope.create_pending_binding(token.value, target) + + nodes.append(target) + + if self.require_single or not stream.get("comma"): break if len(nodes) == 1: @@ -1020,39 +950,8 @@ def __call__(self, stream: TokenStream) -> AstRoot: return node -@dataclass -class FlushPendingIdentifiersParser: - """Parser that flushes pending identifiers.""" - - parser: Parser - - def __call__(self, stream: TokenStream) -> Any: - identifiers = get_stream_identifiers(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) - pending_identifiers_storage = get_stream_pending_identifiers_storage(stream) - - identifiers |= pending_identifiers - pending_identifiers.clear() - identifiers_storage.update(pending_identifiers_storage) - pending_identifiers_storage.clear() - - return self.parser(stream) - - def parse_deferred_root(stream: TokenStream) -> AstDeferredRoot: """Parse deferred root.""" - identifiers = get_stream_identifiers(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - deferred_locals = get_stream_deferred_locals(stream) - class_scope = get_stream_class_scope(stream) - - identifiers.update(pending_identifiers) - pending_identifiers.clear() - - if class_scope is not None: - identifiers = class_scope - stream_copy = stream.copy() with stream.syntax(colon=r":"): @@ -1062,16 +961,12 @@ def parse_deferred_root(stream: TokenStream) -> AstDeferredRoot: while consume_line_continuation(stream): stream.expect("statement") - stream_copy.data["identifiers"] = identifiers | deferred_locals - stream_copy.data["pending_identifiers"] = set() - stream_copy.data["identifiers_storage"] = {loc: "local" for loc in deferred_locals} - stream_copy.data["pending_identifiers_storage"] = {} - stream_copy.data["deferred_locals"] = set() - stream_copy.data["branch_scope"] = set() - stream_copy.data["class_scope"] = None - del stream_copy.data["root_scope"] + lexical_scope = get_stream_lexical_scope(stream) - deferred_locals.clear() + deferred_scope = lexical_scope.deferred(LexicalScope) + stream_copy.data["lexical_scope"] = deferred_scope + del stream_copy.data["root_scope"] + lexical_scope.deferred_complete() node = AstDeferredRoot(stream=stream_copy) return set_location(node, token, stream.current) @@ -1079,18 +974,9 @@ def parse_deferred_root(stream: TokenStream) -> AstDeferredRoot: def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: """Parse function signature.""" - identifiers = get_stream_identifiers(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - deferred_locals = get_stream_deferred_locals(stream) - class_scope = get_stream_class_scope(stream) - - if class_scope is not None: - identifiers = class_scope - - scoped_identifiers = set(identifiers) + lexical_scope = get_stream_lexical_scope(stream) arguments: List[AstFunctionSignatureElement] = [] - new_locals: Set[str] = set() encountered_positional = False encountered_default = False @@ -1110,9 +996,12 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: identifier = stream.expect("identifier") stream.expect(("brace", "(")) - scoped_identifiers.add(identifier.value) + node = set_location(AstFunctionSignature(name=identifier.value), identifier) + lexical_scope.bind_variable(identifier.value, node) - with stream.ignore("newline"): + deferred_scope = lexical_scope.deferred(FunctionScope) + + with stream.ignore("newline"), stream.provide(lexical_scope=deferred_scope): for token in stream.peek_until(("brace", ")")): if encountered_variadic_keyword: exc = InvalidSyntax( @@ -1133,10 +1022,7 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: default = None if stream.get("equal"): encountered_default = True - with stream.provide( - identifiers=scoped_identifiers | new_locals - ): - default = delegate("bolt:expression", stream) + default = delegate("bolt:expression", stream) elif encountered_default: exc = InvalidSyntax( "Argument without default can not appear after arguments with default values." @@ -1147,8 +1033,9 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: name=name.value, default=default, ) - arguments.append(set_location(argument, name, stream.current)) - new_locals.add(argument.name) + argument = set_location(argument, name, stream.current) + deferred_scope.bind_variable(argument.name, argument) + arguments.append(argument) elif separator: if encountered_positional: @@ -1176,8 +1063,9 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: argument = AstFunctionSignatureVariadicKeywordArgument( name=name.value ) - arguments.append(set_location(argument, kwargs, name)) - new_locals.add(argument.name) + argument = set_location(argument, kwargs, name) + deferred_scope.bind_variable(argument.name, argument) + arguments.append(argument) elif args: if encountered_variadic: @@ -1186,8 +1074,9 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: encountered_variadic = True if name := stream.get("identifier"): argument = AstFunctionSignatureVariadicArgument(name=name.value) - arguments.append(set_location(argument, args, name)) - new_locals.add(argument.name) + argument = set_location(argument, args, name) + deferred_scope.bind_variable(argument.name, argument) + arguments.append(argument) else: expect_named_argument = True argument = AstFunctionSignatureVariadicMarker() @@ -1205,24 +1094,25 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature: ) raise set_location(exc, argument) - pending_identifiers.add(identifier.value) - deferred_locals |= new_locals - - node = AstFunctionSignature(name=identifier.value, arguments=AstChildren(arguments)) - return set_location(node, identifier, stream.current) + node = replace(node, arguments=AstChildren(arguments)) + return set_location(node, node, stream.current) def parse_proc_macro_signature(stream: TokenStream): """Parse proc macro signature.""" + with stream.syntax(brace=r"\(|\)", identifier=IDENTIFIER_PATTERN): begin = stream.expect(("brace", "(")) stream.expect(("identifier", "stream")) end = stream.expect(("brace", ")")) - get_stream_pending_identifiers(stream).add("stream") - get_stream_deferred_locals(stream).add("stream") + node = set_location(AstProcMacroMarker(), begin, end) - return set_location(AstProcMacroMarker(), begin, end) + lexical_scope = get_stream_lexical_scope(stream) + deferred_scope = lexical_scope.deferred(ProcMacroScope) + deferred_scope.bind_variable("stream", node) + + return node @dataclass @@ -1236,17 +1126,12 @@ class MacroMatchParser: def __call__(self, stream: TokenStream) -> AstMacroMatch: spec = get_stream_spec(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - deferred_locals = get_stream_deferred_locals(stream) with stream.syntax(equal=r"(? AstMacroMatch: exc = InvalidSyntax(f'Unrecognized argument parser "{parser_name}".') raise set_location(exc, node.match_argument_parser) + lexical_scope = get_stream_lexical_scope(stream) + deferred_scope = lexical_scope.deferred(MacroScope) + deferred_scope.bind_variable(node.match_identifier.value, node) + if commit.rollback: literal = self.literal_parser(stream) node = set_location(AstMacroMatchLiteral(match=literal), literal) @@ -1504,9 +1393,12 @@ def parse_class_name(stream: TokenStream) -> AstClassName: """Parse class name.""" with stream.syntax(identifier=IDENTIFIER_PATTERN): token = stream.expect("identifier") + node = set_location(AstClassName(value=token.value), token) - get_stream_pending_identifiers(stream).add(token.value) - return set_location(AstClassName(value=token.value), token) + lexical_scope = get_stream_lexical_scope(stream) + lexical_scope.create_pending_binding(node.value, node) + + return node def parse_class_bases(stream: TokenStream) -> AstClassBases: @@ -1528,45 +1420,19 @@ def parse_class_bases(stream: TokenStream) -> AstClassBases: return set_location(node, token, stream.current) -def parse_class_root(stream: TokenStream) -> AstRoot: +def parse_class_root(stream: TokenStream) -> AstClassRoot: """Parse class root.""" - with stream.provide(class_scope=get_stream_identifiers(stream)): - return delegate("nested_root", stream) - - -@dataclass -class ClassBodyScoping: - """Handle class body scoping.""" - - parser: Parser - - def __call__(self, stream: TokenStream) -> Any: - class_scope = get_stream_class_scope(stream) - if class_scope is None: - return self.parser(stream) + lexical_scope = get_stream_lexical_scope(stream) - identifiers = get_stream_identifiers(stream) - pending_identifiers = get_stream_pending_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) - pending_identifiers_storage = get_stream_pending_identifiers_storage(stream) - - with stream.provide( - function=False, - identifiers=identifiers.copy(), - pending_identifiers=set(), - identifiers_storage=identifiers_storage.copy(), - pending_identifiers_storage={}, - ): - node = self.parser(stream) - identifiers.update(pending_identifiers) - pending_identifiers.clear() - identifiers_storage.update(pending_identifiers_storage) - pending_identifiers_storage.clear() + with stream.provide(lexical_scope=lexical_scope.push(ClassScope)): + if isinstance(node := delegate("nested_root", stream), AstRoot): + node = set_location(AstClassRoot(commands=node.commands), node) - return set_location(AstClassRoot(commands=node.commands), node) + return node def parse_del_target(stream: TokenStream) -> AstTarget: + """Parse del target.""" node = delegate("bolt:expression", stream) if isinstance(node, AstIdentifier): @@ -1584,7 +1450,7 @@ def parse_del_target(stream: TokenStream) -> AstTarget: def parse_identifier(stream: TokenStream) -> AstIdentifier: """Parse identifier.""" - identifiers = get_stream_identifiers(stream) + lexical_scope = get_stream_lexical_scope(stream) with stream.syntax( true=TRUE_PATTERN, @@ -1593,12 +1459,11 @@ def parse_identifier(stream: TokenStream) -> AstIdentifier: identifier=IDENTIFIER_PATTERN, ): token = stream.expect("identifier") + node = set_location(AstIdentifier(value=token.value), token) - if token.value not in identifiers: - exc = UndefinedIdentifier(token.value, list(identifiers)) - raise set_location(exc, token) + lexical_scope.reference_variable(node.value, node) - return set_location(AstIdentifier(value=token.value), token) + return node @dataclass @@ -1653,30 +1518,27 @@ def check_root_scope(self, stream: TokenStream, node: AstCommand): raise set_location(InvalidSyntax(message), node) def handle_import(self, stream: TokenStream, node: AstCommand) -> AstCommand: - identifiers = get_stream_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) + lexical_scope = get_stream_lexical_scope(stream) if node.identifier == "import:module": if isinstance(module := node.arguments[0], AstResourceLocation): if module.namespace: message = f'Can\'t import "{module.get_value()}" without alias.' raise set_location(InvalidSyntax(message), module) - identifiers.add(module.path.partition(".")[0]) + lexical_scope.bind_variable(module.path.partition(".")[0], module) elif node.identifier == "import:module:as:alias": if isinstance(alias := node.arguments[1], AstImportedItem): if not alias.identifier: exc = InvalidSyntax(f'Invalid identifier "{alias.name}".') raise set_location(exc, alias) - identifiers.add(alias.name) - identifiers_storage.setdefault(alias.name, "local") + lexical_scope.bind_variable(alias.name, alias) return node @internal def handle_from_import(self, stream: TokenStream, node: AstCommand) -> AstCommand: - identifiers = get_stream_identifiers(stream) - identifiers_storage = get_stream_identifiers_storage(stream) + lexical_scope = get_stream_lexical_scope(stream) pending_macros = get_stream_pending_macros(stream) module = cast(AstResourceLocation, node.arguments[0]) @@ -1716,8 +1578,7 @@ def handle_from_import(self, stream: TokenStream, node: AstCommand) -> AstComman elif item.identifier: arguments.append(item) - identifiers.add(item.name) - identifiers_storage.setdefault(item.name, "local") + lexical_scope.bind_variable(item.name, item) else: exc = InvalidSyntax(f'Invalid identifier "{item.name}".') @@ -1766,19 +1627,22 @@ class GlobalNonlocalHandler: parser: Parser - def __call__(self, stream: TokenStream) -> Any: - identifiers_storage = get_stream_identifiers_storage(stream) + storage_qualifiers: Dict[str, Literal["global", "nonlocal"]] = field( + default_factory=lambda: { + "global:subcommand": "global", + "nonlocal:subcommand": "nonlocal", + } + ) + def __call__(self, stream: TokenStream) -> Any: if isinstance(node := self.parser(stream), AstCommand): - if node.identifier in ["global:subcommand", "nonlocal:subcommand"]: - storage, _, _ = node.identifier.partition(":") + if storage := self.storage_qualifiers.get(node.identifier): subcommand = cast(AstCommand, node.arguments[0]) + + lexical_scope = get_stream_lexical_scope(stream) while True: if isinstance(name := subcommand.arguments[0], AstIdentifier): - s = identifiers_storage.setdefault(name.value, storage) - if s != storage: - exc = InvalidSyntax(f"Can't make {s} identifier {storage}.") - raise set_location(exc, name) + lexical_scope.bind_storage(name.value, storage, name) if subcommand.identifier == f"{storage}:name:subcommand": subcommand = cast(AstCommand, subcommand.arguments[1]) else: @@ -1787,6 +1651,25 @@ def __call__(self, stream: TokenStream) -> Any: return node +@dataclass +class FlushPendingBindingsParser: + """Parser that flushes pending bindings.""" + + parser: Parser + before: bool = False + after: bool = False + + def __call__(self, stream: TokenStream) -> Any: + if self.before: + lexical_scope = get_stream_lexical_scope(stream) + lexical_scope.flush_pending_bindings() + node = self.parser(stream) + if self.after: + lexical_scope = get_stream_lexical_scope(stream) + lexical_scope.flush_pending_bindings() + return node + + @dataclass class StatementSubcommandHandler: """Prevent statements as subcommands.""" @@ -1815,7 +1698,8 @@ class DeferredRootBacktracker: def __call__(self, stream: TokenStream) -> AstRoot: node: AstRoot = self.parser(stream) - if isinstance(node, AstClassRoot): + lexical_scope = get_stream_lexical_scope(stream) + if isinstance(lexical_scope, ClassScope): return node return self.resolve_deferred(node, stream) @@ -1855,10 +1739,9 @@ def resolve_deferred(self, node: AstRoot, stream: TokenStream) -> AstRoot: deferred_stream = deferred_root.stream deferred_stream.data["local_spec"] = False deferred_stream.data["spec"] = get_stream_spec(stream) - deferred_stream.data["identifiers"] |= get_stream_identifiers(stream) - deferred_stream.data["function"] = True deferred_stream.data["macro_scope"] = get_stream_macro_scope(stream) deferred_stream.data["pending_macros"] = [] + nested_root = delegate("nested_root", deferred_stream) self.macro_handler.cache_local_spec(deferred_stream) @@ -1886,16 +1769,18 @@ def resolve_deferred(self, node: AstRoot, stream: TokenStream) -> AstRoot: @dataclass -class FunctionConstraint: - """Constraint that makes sure that the given statements only occur in functions.""" +class LexicalScopeConstraint: + """Constraint that restricts specific statements to the given scope type.""" parser: Parser + type: Type[LexicalScope] command_identifiers: Set[str] def __call__(self, stream: TokenStream) -> AstRoot: node = self.parser(stream) - if stream.data.get("function"): + lexical_scope = get_stream_lexical_scope(stream) + if isinstance(lexical_scope, self.type): return node if isinstance(node, AstRoot): @@ -2248,6 +2133,7 @@ class BuiltinCallRestriction: """Only allow call expressions on builtins.""" parser: Parser + builtins: Set[str] def __call__(self, stream: TokenStream) -> Any: parent = None @@ -2258,11 +2144,12 @@ def __call__(self, stream: TokenStream) -> Any: parent = node node = node.value + lexical_scope = get_stream_lexical_scope(stream) if ( isinstance(node, AstIdentifier) and not isinstance(parent, AstCall) - and node.value in get_stream_builtins(stream) - and node.value not in get_stream_identifiers_storage(stream) + and node.value in self.builtins + and not lexical_scope.has_binding(node.value, search_parents=True) ): # Reset the underlying token generator so that the identifier can # be re-parsed as a different token. That shouldn't be necessary @@ -2277,21 +2164,19 @@ def __call__(self, stream: TokenStream) -> Any: def parse_dict_item(stream: TokenStream) -> Any: """Parse dict item node.""" + with stream.syntax(colon=r":", identifier=IDENTIFIER_PATTERN): with stream.checkpoint() as commit: - identifier = stream.expect("identifier") + token = stream.expect("identifier") stream.expect("colon") commit() - if identifier.value in get_stream_identifiers(stream) and ( - not identifier.value in get_stream_builtins(stream) - or identifier.value in get_stream_identifiers_storage(stream) - ): - key = AstIdentifier(value=identifier.value) + lexical_scope = get_stream_lexical_scope(stream) + if lexical_scope.has_binding(token.value, search_parents=True): + key = set_location(AstIdentifier(value=token.value), token) + lexical_scope.reference_variable(key.value, key) else: - key = AstDictUnquotedKey(value=identifier.value) - - key = set_location(key, identifier) + key = set_location(AstDictUnquotedKey(value=token.value), token) if commit.rollback: key = delegate("bolt:expression", stream) diff --git a/bolt/semantics.py b/bolt/semantics.py new file mode 100644 index 0000000..527c041 --- /dev/null +++ b/bolt/semantics.py @@ -0,0 +1,366 @@ +__all__ = [ + "LexicalScope", + "GlobalScope", + "FunctionScope", + "MacroScope", + "ProcMacroScope", + "ClassScope", + "Variable", + "Binding", + "UndefinedIdentifier", + "UnboundLocalIdentifier", + "InconsistentIdentifierStorage", +] + + +from dataclasses import dataclass, field, replace +from typing import Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar + +from beet.core.utils import extra_field +from mecha import AstNode +from tokenstream import InvalidSyntax, set_location + +from .utils import suggest_typo + +LexicalScopeType = TypeVar("LexicalScopeType", bound="LexicalScope") + + +class UndefinedIdentifier(InvalidSyntax): + """Raised when an identifier is not defined. + + Attributes + ---------- + identifier + The identifier that couldn't be found. + lexical_scope + The current scope. + """ + + identifier: str + lexical_scope: "LexicalScope" + + def __init__(self, identifier: str, lexical_scope: "LexicalScope"): + super().__init__(identifier, lexical_scope) + self.identifier = identifier + self.lexical_scope = lexical_scope + + def __str__(self) -> str: + msg = f'Identifier "{self.identifier}" is not defined.' + + variable_names = self.lexical_scope.list_variables() - {self.identifier} + if suggestion := suggest_typo(self.identifier, variable_names): + msg += f" Did you mean {suggestion}?" + + return msg + + +class UnboundLocalIdentifier(UndefinedIdentifier): + """Raised when mutating a variable before assignment. + + This is a specialized error for code that tries to access and mutate a variable + with no local binding in the current scope. + """ + + def __init__(self, identifier: str, lexical_scope: "LexicalScope"): + super().__init__(identifier, lexical_scope) + if lexical_scope.parent and lexical_scope.parent.has_variable(identifier): + self.notes.append( + f'Use "global {identifier}" or "nonlocal {identifier}" to mutate the variable defined in outer scope.' + ) + + +class InconsistentIdentifierStorage(InvalidSyntax): + """Raised when trying to change the storage of an already existing variable. + + Attributes + ---------- + identifier + The identifier with a prior conflicting definition in the current scope. + storage + The target storage of the inconsistent binding. + lexical_scope + The current scope. + """ + + identifier: str + storage: Literal["local", "nonlocal", "global"] + lexical_scope: "LexicalScope" + + def __init__( + self, + identifier: str, + storage: Literal["local", "nonlocal", "global"], + lexical_scope: "LexicalScope", + ): + super().__init__(identifier, storage, lexical_scope) + self.identifier = identifier + self.storage = storage + self.lexical_scope = lexical_scope + + def __str__(self) -> str: + return f'Can\'t make {self.lexical_scope.variables[self.identifier].storage} identifier "{self.identifier}" {self.storage}.' + + +@dataclass +class Binding: + """Variable binding. + + Attributes + ---------- + origin + The code fragment performing the variable binding. + references + Other parts of the code that use the value set by the binding. + cutoff + The number of references in common after forking. + """ + + origin: AstNode + references: List[AstNode] = field(default_factory=list) + + cutoff: int = extra_field(default=0) + + def fork(self) -> "Binding": + return replace( + self, + references=self.references.copy(), + cutoff=len(self.references), + ) + + def reconcile(self, binding: "Binding"): + self.references += binding.references[binding.cutoff :] + + +@dataclass +class Variable: + """Variable definition. + + Attributes + ---------- + storage + The type of storage for the variable. + bindings + Associated bindings updating the value or the storage of the variable. + cutoff + The number of bindings in common after forking. + """ + + storage: Literal["local", "nonlocal", "global"] = "local" + bindings: List[Binding] = field(default_factory=list) + + cutoff: int = extra_field(default=0) + + def fork(self): + return replace( + self, + bindings=[binding.fork() for binding in self.bindings], + cutoff=len(self.bindings), + ) + + def reconcile(self, variable: "Variable"): + for i, binding in enumerate(variable.bindings[: variable.cutoff]): + self.bindings[i].reconcile(binding) + + for new_binding in variable.bindings[variable.cutoff :]: + self.bindings.append(new_binding) + + +@dataclass +class LexicalScope: + """Class for managing variable scopes during parsing. + + Attributes + ---------- + parent + Parent scope if the current scope has one. + variables + Dictionary holding variable definitions. + dirty + Flag to skip reconciliation after forking in case no additional + variables or bindings were introduced. + """ + + variables: Dict[str, Variable] = field(default_factory=dict) + + anchor: Optional[AstNode] = extra_field(default=None) + parent: Optional["LexicalScope"] = extra_field(default=None) + children: List["LexicalScope"] = extra_field(default_factory=list) + + dirty: bool = extra_field(default=False) + + next_branch: Optional["LexicalScope"] = extra_field(default=None) + next_deferred: Optional["LexicalScope"] = extra_field(default=None) + pending_bindings: List[Tuple[str, AstNode]] = extra_field(default_factory=list) + + def has_binding(self, identifier: str, search_parents: bool = False) -> bool: + return identifier in self.variables or bool( + search_parents + and self.parent + and self.parent.has_binding(identifier, search_parents=True) + ) + + def reference_binding(self, identifier: str, node: AstNode): + if not self.has_binding(identifier): + raise set_location(UnboundLocalIdentifier(identifier, self), node) + self.reference_variable(identifier, node) + + def has_variable(self, identifier: str) -> bool: + return identifier in self.variables or bool( + self.parent and self.parent.has_variable(identifier) + ) + + def list_variables(self) -> Set[str]: + all_variables = set(self.variables) + if self.parent: + all_variables |= self.parent.list_variables() + return all_variables + + def reference_variable( + self, + identifier: str, + node: AstNode, + scope: Optional["LexicalScope"] = None, + ): + if not scope: + scope = self + + if variable := self.variables.get(identifier): + binding = variable.bindings[-1] + + if not binding.references or binding.references[-1] is not node: + self.dirty = True + binding.references.append(node) + + if variable.storage != "local" and self.parent: + self.parent.reference_variable(identifier, node, scope) + + return + + if not self.parent: + raise set_location(UndefinedIdentifier(identifier, scope), node) + + self.parent.reference_variable(identifier, node, scope) + + def bind_variable(self, identifier: str, node: AstNode): + self.dirty = True + + if self.has_binding(identifier): + self.reference_variable(identifier, node) + + variable = self.variables.get(identifier) + if not variable: + variable = Variable() + self.variables[identifier] = variable + + variable.bindings.append(Binding(origin=node)) + + if variable.storage != "local" and self.parent: + self.parent.bind_variable(identifier, node) + + def bind_storage( + self, + identifier: str, + storage: Literal["nonlocal", "global"], + node: AstNode, + ): + self.dirty = True + + self.reference_variable(identifier, node) + + variable = self.variables.get(identifier) + if not variable: + variable = Variable(storage=storage) + self.variables[identifier] = variable + elif variable.storage != storage: + exc = InconsistentIdentifierStorage(identifier, storage, self) + raise set_location(exc, node) + + variable.bindings.append(Binding(origin=node)) + + def push(self, scope_type: Type["LexicalScopeType"]) -> "LexicalScopeType": + child = scope_type(parent=self) + self.children.append(child) + return child + + def deferred(self, scope_type: Type["LexicalScopeType"]) -> "LexicalScopeType": + if not isinstance(self.next_deferred, scope_type): + self.next_deferred = self.push(scope_type) + return self.next_deferred # type: ignore + + def deferred_complete(self): + self.next_deferred = None + + def fork(self) -> "LexicalScope": + return replace( + self, + parent=self.parent and self.parent.fork(), + dirty=False, + variables={ + identifier: variable.fork() + for identifier, variable in self.variables.items() + }, + ) + + def reconcile(self, lexical_scope: "LexicalScope"): + if lexical_scope.dirty: + self.dirty = True + for identifier, variable in lexical_scope.variables.items(): + if current_variable := self.variables.get(identifier): + current_variable.reconcile(variable) + else: + self.variables[identifier] = variable + if self.parent and lexical_scope.parent: + self.parent.reconcile(lexical_scope.parent) + + def create_pending_binding(self, identifier: str, node: AstNode): + self.pending_bindings.append((identifier, node)) + + def flush_pending_bindings(self): + for identifier, node in self.pending_bindings: + self.bind_variable(identifier, node) + self.pending_bindings.clear() + + +@dataclass +class GlobalScope(LexicalScope): + """Specialized global scope to fork more efficiently.""" + + identifiers: Set[str] = field(default_factory=set) + + def has_variable(self, identifier: str) -> bool: + return identifier in self.identifiers or super().has_variable(identifier) + + def list_variables(self) -> Set[str]: + return self.identifiers | super().list_variables() + + def reference_variable( + self, + identifier: str, + node: AstNode, + scope: Optional["LexicalScope"] = None, + ): + if identifier not in self.identifiers or identifier in self.variables: + super().reference_variable(identifier, node, scope) + + +class FunctionScope(LexicalScope): + """Dedicated type for function scope.""" + + +class MacroScope(FunctionScope): + """Dedicated type for macro scope.""" + + +class ProcMacroScope(FunctionScope): + """Dedicated type for proc macro scope.""" + + +class ClassScope(LexicalScope): + """Dedicated type for class scope.""" + + def push(self, scope_type: Type["LexicalScopeType"]) -> "LexicalScopeType": + if issubclass(scope_type, FunctionScope): + if not self.parent: + raise ValueError("Class scope has no parent.") + return self.parent.push(scope_type) + return super().push(scope_type) diff --git a/bolt/utils.py b/bolt/utils.py index 1c90f3c..e430d85 100644 --- a/bolt/utils.py +++ b/bolt/utils.py @@ -74,7 +74,9 @@ def fake_traceback(exc: Exception, tb: TracebackType, lineno: int) -> TracebackT def suggest_typo(wrong: str, possibilities: Iterable[str]) -> Optional[str]: - if matches := get_close_matches(wrong, possibilities): + cutoff = 0.6 if len(wrong) < 3 else 0.7 + + if matches := get_close_matches(wrong, possibilities, cutoff=cutoff): matches = [f'"{m}"' for m in matches] if len(matches) == 1: diff --git a/tests/snapshots/bolt__parse_125__0.txt b/tests/snapshots/bolt__parse_125__0.txt index 69d92bb..03ca160 100644 --- a/tests/snapshots/bolt__parse_125__0.txt +++ b/tests/snapshots/bolt__parse_125__0.txt @@ -1,5 +1,5 @@ import math -#>ERROR Identifier "mah" is not defined. +#>ERROR Identifier "mah" is not defined. Did you mean "math"? # line 2, column 1 # 1 | import math # 2 | mah.sin(1) diff --git a/tests/snapshots/bolt__parse_188__0.txt b/tests/snapshots/bolt__parse_188__0.txt index e5b5e93..46e3496 100644 --- a/tests/snapshots/bolt__parse_188__0.txt +++ b/tests/snapshots/bolt__parse_188__0.txt @@ -1,5 +1,5 @@ x = 1 -#>ERROR Can't make local identifier global. +#>ERROR Can't make local identifier "x" global. # line 2, column 8 # 1 | x = 1 # 2 | global x diff --git a/tests/snapshots/bolt__parse_190__0.txt b/tests/snapshots/bolt__parse_190__0.txt index 3ce10b1..ad0409b 100644 --- a/tests/snapshots/bolt__parse_190__0.txt +++ b/tests/snapshots/bolt__parse_190__0.txt @@ -1,5 +1,5 @@ def f(x): -#>ERROR Can't make local identifier global. +#>ERROR Can't make local identifier "x" global. # line 2, column 12 # 1 | def f(x): # 2 | global x diff --git a/tests/snapshots/bolt__parse_191__0.txt b/tests/snapshots/bolt__parse_191__0.txt index 742aaf0..bfeda5e 100644 --- a/tests/snapshots/bolt__parse_191__0.txt +++ b/tests/snapshots/bolt__parse_191__0.txt @@ -1,7 +1,7 @@ x = 1 def f(): global x -#>ERROR Can't make global identifier nonlocal. +#>ERROR Can't make global identifier "x" nonlocal. # line 4, column 14 # 3 | global x # 4 | nonlocal x diff --git a/tests/snapshots/bolt__parse_195__0.txt b/tests/snapshots/bolt__parse_195__0.txt index e60aee0..3aaff46 100644 --- a/tests/snapshots/bolt__parse_195__0.txt +++ b/tests/snapshots/bolt__parse_195__0.txt @@ -1,5 +1,5 @@ for i in range(3): -#>ERROR Identifier "parent" is not defined. +#>ERROR Identifier "parent" is not defined. Did you mean "print"? # line 2, column 5 # 1 | for i in range(3): # 2 | parent[i] = f"param{i}" diff --git a/tests/snapshots/bolt__parse_200__0.txt b/tests/snapshots/bolt__parse_200__0.txt index 1d605bf..d1da8b0 100644 --- a/tests/snapshots/bolt__parse_200__0.txt +++ b/tests/snapshots/bolt__parse_200__0.txt @@ -6,6 +6,6 @@ def f(): # 3 | a += 1 # : ^ # Notes: -# - Use 'global a' or 'nonlocal a' to mutate the variable defined in outer scope. +# - Use "global a" or "nonlocal a" to mutate the variable defined in outer scope. # - Expected assignment but got literal '+='. a += 1 diff --git a/tests/snapshots/bolt__parse_58__0.txt b/tests/snapshots/bolt__parse_58__0.txt index 4fec1ba..506903a 100644 --- a/tests/snapshots/bolt__parse_58__0.txt +++ b/tests/snapshots/bolt__parse_58__0.txt @@ -1,4 +1,4 @@ -#>ERROR Identifier "the" is not defined. Did you mean "hex"? +#>ERROR Identifier "the" is not defined. # line 1, column 7 # 1 | wat = the = "f" # : ^^^ diff --git a/tests/test_semantics.py b/tests/test_semantics.py new file mode 100644 index 0000000..de3ab5a --- /dev/null +++ b/tests/test_semantics.py @@ -0,0 +1,133 @@ +import pytest + +from bolt import ( + AstIdentifier, + Binding, + FunctionScope, + GlobalScope, + InconsistentIdentifierStorage, + LexicalScope, + UnboundLocalIdentifier, + UndefinedIdentifier, + Variable, +) + + +def test_variable(): + lexical_scope = LexicalScope() + assert not lexical_scope.has_variable("foo") + + lexical_scope.bind_variable("foo", AstIdentifier(value="foo")) + assert lexical_scope.has_variable("foo") + assert lexical_scope.list_variables() == {"foo"} + + lexical_scope.reference_variable("foo", AstIdentifier(value="foo")) + assert lexical_scope.variables == { + "foo": Variable( + bindings=[ + Binding( + origin=AstIdentifier(value="foo"), + references=[AstIdentifier(value="foo")], + ) + ] + ) + } + + +def test_global(): + global_scope = GlobalScope(identifiers={"THING"}) + + assert global_scope.has_variable("THING") + assert global_scope.list_variables() == {"THING"} + + assert not global_scope.has_binding("THING") + + lexical_scope = global_scope.push(LexicalScope) + assert lexical_scope.has_variable("THING") + assert lexical_scope.list_variables() == {"THING"} + + lexical_scope.reference_variable("THING", AstIdentifier(value="THING")) + assert lexical_scope.variables == {} + assert global_scope.variables == {} + + +def test_children(): + lexical_scope = LexicalScope() + + with pytest.raises(UndefinedIdentifier, match='"foo" is not defined'): + lexical_scope.reference_variable("foo", AstIdentifier(value="foo")) + with pytest.raises(UnboundLocalIdentifier, match='"foo" is not defined'): + lexical_scope.reference_binding("foo", AstIdentifier(value="foo")) + + lexical_scope.bind_variable("foo", AstIdentifier(value="foo")) + + child_scope = lexical_scope.push(FunctionScope) + + with pytest.raises(UnboundLocalIdentifier) as exc_info: + child_scope.reference_binding("foo", AstIdentifier(value="foo")) + assert exc_info.value.notes == [ + 'Use "global foo" or "nonlocal foo" to mutate the variable defined in outer scope.' + ] + + alternative_scope = child_scope.fork() + + child_scope.bind_variable("foo", AstIdentifier(value="foo")) + + with pytest.raises(InconsistentIdentifierStorage): + child_scope.bind_storage("foo", "nonlocal", AstIdentifier(value="foo")) + + alternative_scope.bind_storage("foo", "nonlocal", AstIdentifier(value="foostorage")) + alternative_scope.reference_variable("foo", AstIdentifier(value="fooparent")) + alternative_scope.bind_variable("foo", AstIdentifier(value="foovalue")) + alternative_scope.reference_variable("foo", AstIdentifier(value="fooref")) + + assert alternative_scope.variables == { + "foo": Variable( + storage="nonlocal", + bindings=[ + Binding( + origin=AstIdentifier(value="foostorage"), + references=[ + AstIdentifier(value="fooparent"), + AstIdentifier(value="foovalue"), + ], + ), + Binding( + origin=AstIdentifier(value="foovalue"), + references=[AstIdentifier(value="fooref")], + ), + ], + ) + } + + assert lexical_scope.variables == { + "foo": Variable( + bindings=[ + Binding( + origin=AstIdentifier(value="foo"), + references=[], + ) + ] + ) + } + + child_scope.reconcile(alternative_scope) + + assert lexical_scope.variables == { + "foo": Variable( + bindings=[ + Binding( + origin=AstIdentifier(value="foo"), + references=[ + AstIdentifier(value="foostorage"), + AstIdentifier(value="fooparent"), + AstIdentifier(value="foovalue"), + ], + ), + Binding( + origin=AstIdentifier(value="foovalue"), + references=[AstIdentifier(value="fooref")], + ), + ] + ) + }