diff --git a/outlines/text/parsing.py b/outlines/text/parsing.py
index ec9ea5c27..c4c801410 100644
--- a/outlines/text/parsing.py
+++ b/outlines/text/parsing.py
@@ -9,6 +9,7 @@
     Dict,
     Generator,
     Iterable,
+    List,
     Optional,
     Sequence,
     Set,
@@ -29,7 +30,7 @@
 from lark.lexer import BasicLexer, ContextualLexer, LexerState, Scanner
 from lark.parsers.lalr_analysis import Shift
 from lark.parsers.lalr_interactive_parser import InteractiveParser
-from lark.parsers.lalr_parser import ParseConf, ParserState
+from lark.parsers.lalr_parser import ParserState
 
 if TYPE_CHECKING:
     from lark.lexer import LexerThread
@@ -73,17 +74,12 @@ def __init__(self, scanner: Scanner):
 
         self.fsm, self.fsms_to_trans_finals = fsm_union(fsms)
 
-    def match(self, text, pos):
-        """Get the match end position, terminal type, and final FSM state."""
-
-        text_part = text[pos:]
-        res = find_partial_matches(self.fsm, text_part, start_state=self.fsm.initial)
-
-        if len(res) == 0:
-            return None
-
-        ((lex_end, state_seq),) = res
+    def get_terminal_from_state_seq(self, state_seq) -> Tuple[str, Optional[int]]:
+        """Get the first complete or partial terminal name match for the sequence.
+        This also returns the last FSM state when the FSM still accepts more
+        input.
+        """
 
         # We need to figure out to which sub-FSMs/terminals this match could
         # correspond
         res = None
@@ -96,10 +92,25 @@ def match(self, text, pos):
                 res = _res
                 break
 
+        assert res is not None
+
         (fsm_id, has_transition) = res
 
         type_name = self.terminals[fsm_id].name
 
-        return lex_end, type_name, state_seq[-1] if has_transition else None
+        return type_name, state_seq[-1] if has_transition else None
+
+    def match(self, text, pos):
+        """Get the match end position, terminal type, and final FSM state."""
+
+        text_part = text[pos:]
+        res = find_partial_matches(self.fsm, text_part, start_state=self.fsm.initial)
+
+        if len(res) == 0:
+            return None
+
+        ((lex_end, state_seq),) = res
+
+        return (lex_end,) + self.get_terminal_from_state_seq(state_seq)
 
 
 class PartialBasicLexer(BasicLexer):
@@ -296,7 +307,7 @@ def parse_to_end(
 
 def find_partial_matches(
     fsm: FSM, input_string: str, start_state: Optional[int] = None
-) -> Set[Tuple[Optional[int], Tuple[int, ...]]]:
+) -> Set[Tuple[int, Tuple[int, ...]]]:
     """Find the states in the finite state machine `fsm` that accept `input_string`.
 
     This will consider all possible states in the finite state machine (FSM)
@@ -366,7 +377,7 @@ def _partial_match(
     for state, trans in transition_maps.items():
         if trans_key in trans:
             last_match_idx, path = _partial_match(trans)
-            if path is not None:
+            if last_match_idx is not None and path is not None:
                 res.add((last_match_idx, (state,) + path))
 
     return res
@@ -446,44 +457,17 @@ def map_partial_states_to_vocab(
     return pstate_to_vocab, possible_paths
 
 
-def terminals_to_lalr_states(lp: Lark) -> DefaultDict[str, Set[int]]:
-    terminals_to_states = defaultdict(set)
+def terminal_reverse_maps(lp: Lark):
+    reverse_shifts = defaultdict(set)
+    symbols_to_states = defaultdict(set)
 
     parse_table = lp.parser.parser.parser.parse_table
 
     for state, tokens_to_ops in parse_table.states.items():
         for token, op in tokens_to_ops.items():
             if op[0] == Shift:
-                # `op[1]` is the state we shift to when `token` is observed
-                terminals_to_states[token].add(op[1])
-
-    return terminals_to_states
+                reverse_shifts[op[1]].add((token, state))
+                symbols_to_states[token].add((state, op))
 
-
-def create_pmatch_parser_states(
-    lp: Lark,
-    terminals_to_states: Dict[str, Set[int]],
-    term_type: str,
-    ptoken: str,
-    pmatch: Tuple[int, Tuple[int, ...]],
-) -> Tuple[ParserState, ...]:
-    parse_table = lp.parser.parser.parser.parse_table
-
-    # TODO: We need to effectively disable the callbacks that build the
-    # trees, because we aren't actually parsing a valid state that can, say,
-    # be reduced
-    def noop(*args, **kwargs):
-        pass
-
-    callbacks = {rule: noop for rule, cb in lp._callbacks.items()}
-    parse_conf = ParseConf(parse_table, callbacks, lp.options.start[0])
-    lexer_thread = lp.parser._make_lexer_thread(ptoken)
-    lexer_state = lexer_thread.state
-    lexer_state.line_ctr.char_pos = pmatch[0] + 1
-    lexer_state.last_token = Token(term_type, "")
-    res = tuple(
-        ParserState(parse_conf, lexer_thread, [state], None)
-        for state in terminals_to_states[term_type]
-    )
-    return res
+    return dict(symbols_to_states), dict(reverse_shifts)
 
 
 def fsm_union(fsms):
@@ -607,3 +591,282 @@ def get_sub_fsms_from_seq(
         for fsm_idx, (transitions, finals) in fsms_to_trans_finals.items()
         if pmatch_transitions.issubset(transitions)
     )
+
+
+def get_lex_token_seqs(lp: Lark, v: str) -> Set[Tuple[int, Tuple[str, ...]]]:
+    """Lex a string `v` and obtain FSM start states and terminal token sequences.
+
+    Parameters
+    ----------
+    lp
+        The `lark` object holding the parser and its configuration.
+    v
+        The string to fully lex.
+
+    """
+    context_lexer = lp.parser.lexer.lexer
+
+    results = set()
+
+    terminals = context_lexer.root_lexer.terminals
+    ignored_types = context_lexer.root_lexer.ignore_types
+
+    possible_seqs: List[Tuple[int, Optional[int], Tuple[str, ...]]] = [
+        (0, None, ()),
+    ]
+
+    # The Python grammar has ~100 terminals
+    fsms = []
+    fsm_prefix_postfixes = []
+    max_prefix = 0
+    max_postfix = 0
+    for t in terminals:
+        regex_str = t.pattern.to_regexp()
+        pattern = interegular.parse_pattern(regex_str)
+
+        fsm_prefix_postfixes.append(pattern.prefix_postfix)
+        max_prefix = max(max_prefix, pattern.prefix_postfix[0])
+        max_postfix = max(max_postfix, pattern.prefix_postfix[1])
+
+        # TODO FIXME: We don't support this right now.
+        assert max_prefix == 0
+        assert max_postfix == 0
+
+        fsm = pattern.to_fsm().reduce()
+        fsms.append(fsm)
+
+    # TODO: Determine if there is a transition from the previous
+    # terminal to this one.
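+    # One possible approach (a sketch, not implemented here): use the maps from
+    # `terminal_reverse_maps` to find the parse-table states reachable after
+    # shifting the last terminal in `state_seq`, and keep `terminal_name` only
+    # when one of those states has an action on it. The stub below simply
+    # accepts every pair of terminals.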
+    def can_transition(terminal_name, state_seq):
+        return True
+
+    while possible_seqs:
+        last_pos, start_fsm_state, state_seq = possible_seqs.pop(0)
+
+        text_part = v[last_pos:]
+
+        for i, fsm in enumerate(fsms):
+            terminal_name = terminals[i].name
+
+            if not can_transition(terminal_name, state_seq):
+                continue
+
+            # The initial runs start matching anywhere in the FSMs. Once those
+            # start branching out, and `start_fsm_state` is not `None`, we start
+            # the FSMs at their initial states.
+            start_state = fsm.initial if start_fsm_state is not None else None
+
+            res = find_partial_matches(fsm, text_part, start_state=start_state)
+
+            for end_pos, fsm_state_seq in res:
+                next_pos = last_pos + end_pos + 1
+
+                is_ignored = terminal_name in ignored_types
+
+                if not is_ignored:
+                    _start_fsm_state: Optional[int] = (
+                        fsm_state_seq[0] if start_fsm_state is None else start_fsm_state
+                    )
+                    next_state_seq = state_seq + (terminal_name,)
+                else:
+                    _start_fsm_state = start_fsm_state
+                    next_state_seq = state_seq
+
+                if next_pos == len(v):
+                    # TODO: What do we want to do about strings like `" "`?
+                    # In this case, `_start_fsm_state is None` and
+                    # `next_state_seq == ()`
+                    assert _start_fsm_state is not None
+                    # We've hit the end of the string
+                    results.add((_start_fsm_state, next_state_seq))
+                else:
+                    # This token isn't finished; keep scanning
+                    possible_seqs.append((next_pos, _start_fsm_state, next_state_seq))
+
+    return results
+
+
+def set_index(index, key_seq, value):
+    """Add `value` to the path `key_seq` in the trie-like `index`."""
+    _index = index
+    for key in key_seq:
+        tokens, _index = _index.setdefault(key, (set(), {}))
+        tokens.add(value)
+    return index
+
+
+def get_index(index, key_seq):
+    """Get the values under the path `key_seq` in the trie-like `index`."""
+    tokens = None
+    _index = index
+    for key in key_seq:
+        res = _index.get(key)
+        if res is None:
+            return tokens
+        tokens, _index = res
+
+    return tokens
+
+
+def create_parser_vocab_index(lp: Lark, vocabulary: Iterable[str]):
+    """Create a trie-like index for the strings in `vocabulary`.
+
+    The keys of the resulting trie-like index consist of the FSM start states
+    concatenated with the parser state sequences.
+
+    Parameters
+    ----------
+    lp
+        The `lark` object holding the parser and its configuration.
+    vocabulary
+        The vocabulary containing strings to be indexed by their lexer FSM and
+        parser states.
+
+    """
+
+    symbols_to_states, reverse_shifts = terminal_reverse_maps(lp)
+
+    vocab_index: Dict = {}
+    for v in vocabulary:
+        res = get_lex_token_seqs(lp, v)
+
+        for fsm_start, token_seq in res:
+            # Convert token sequences to parser state sequences
+            state_seqs: List[List] = []
+            for token in token_seq:
+                state_seqs = parse_to_terminal(
+                    lp, state_seqs, token, symbols_to_states, reverse_shifts
+                )
+
+            for state_seq in state_seqs:
+                assert len(state_seq) > 1
+                vocab_index = set_index(
+                    # We need to remove the last states from the sequences,
+                    # because they're irrelevant (i.e. they correspond to the
+                    # states *after* reading a token).
+                    # Also, the sequences need to be reversed so that we can do
+                    # prefix searches against the top of a stack.
+                    vocab_index,
+                    (fsm_start,) + tuple(reversed(state_seq[:-1])),
+                    v,
+                )
+
+    return vocab_index
+
+
+def parse_to_terminal(
+    lp: Lark,
+    state_stacks: List[List[int]],
+    terminal_name: str,
+    symbols_to_states,
+    reverse_shifts,
+):
+    """Take a single step of the parser reading `terminal_name` starting from each stack in `state_stacks`.
+
+    In the case of reduction states, this will walk backwards and generate all
+    the viable originating state stacks and then walk them forward from the
+    reduction.
+
+    These results can be used to determine the minimal top-of-stack states that
+    will accept a sequence of tokens (i.e. when `state_stacks` results from
+    previous calls to this function using other tokens).
+
+    Parameters
+    ----------
+    lp
+        The parser configuration/information.
+    state_stacks
+        The parser state stacks. The top of each stack is on the right (i.e.
+        ``state_stacks[n][-1]``).
+    terminal_name
+        The terminal read by the parser.
+    symbols_to_states
+        A map from symbol names to the ``(state, op)`` pairs that shift on
+        them, as returned by `terminal_reverse_maps`.
+    reverse_shifts
+        A map from each shift destination state to the ``(symbol, state)``
+        pairs that shift into it, as returned by `terminal_reverse_maps`.
+
+    Returns
+    -------
+    All of the possible state stacks resulting from the parse step.
+
+    """
+    parse_table = lp.parser.parser.parser.parse_table
+
+    if len(state_stacks) == 0:
+        res = {state for state, _ in symbols_to_states[terminal_name]}
+        state_stacks = [[state] for state in res]
+
+    new_stacks = []
+    # Run the parser forward until a reduction is found, and
+    # generate new sequences with additional antecedent states
+    # that satisfy the reduction's requirements.
+    # The resulting state sequences represent the stack configuration
+    # prefixes that accept the tracked sequence of tokens.
+    candidate_stacks = list(state_stacks)
+    while candidate_stacks:
+        state_stack = candidate_stacks.pop()
+        current_state = state_stack[-1]
+        action_arg = parse_table.states[current_state].get(terminal_name)
+
+        if action_arg is None:
+            # This state doesn't accept the terminal type
+            continue
+
+        action, arg = action_arg
+
+        if action is Shift:
+            # Move to the next state
+            new_stacks.append(state_stack + [arg])
+            continue
+        else:
+            rule = arg
+            expansion_size = len(rule.expansion)
+            stack_size = len(state_stack)
+
+            if stack_size > expansion_size:
+                # There are already enough states in this stack, so just
+                # walk forward
+                del state_stack[stack_size - expansion_size :]
+
+                _action, new_state = parse_table.states[state_stack[-1]][
+                    rule.origin.name
+                ]
+                assert _action is Shift
+                state_stack.append(new_state)
+
+                candidate_stacks.append(state_stack)
+            else:
+                # There aren't enough states in the stack for this
+                # reduction, so we need to generate stacks for this
+                # particular situation
+
+                # Create state sequences that would've been popped and reduced
+                # TODO: Could we use wildcard states for every state but the last?
+                reverse_expansions = set()
+                expansion_candidates = [(state_stack[0], expansion_size - stack_size)]
+                while expansion_candidates:
+                    expansion_state, expansion_idx = expansion_candidates.pop()
+                    expansion_symbol = rule.expansion[expansion_idx]
+                    for symbol, from_state in reverse_shifts[expansion_state]:
+                        if symbol != expansion_symbol.name:
+                            continue
+                        next_expansion_idx = expansion_idx - 1
+                        if next_expansion_idx < 0:
+                            reverse_expansions.add(from_state)
+                        else:
+                            expansion_candidates.append(
+                                (from_state, next_expansion_idx)
+                            )
+
+                # Now, of those sequences we just produced, find antecedents that
+                # accept `rule.origin.name`
+                next_stacks = set()
+                while reverse_expansions:
+                    subsequent_state = reverse_expansions.pop()
+                    transition = parse_table.states[subsequent_state].get(
+                        rule.origin.name
+                    )
+                    if transition is not None and transition[0] is Shift:
+                        next_stacks.add(transition[1])
+
+                candidate_stacks.extend([stack] for stack in next_stacks)
+
+    return new_stacks
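A note on the trie layout used by `set_index`/`get_index` above: every node is a
``(tokens, children)`` pair, and a value is added to the token set of every node
along its key path. A minimal sketch of that behaviour (mirroring
`test_index_operations` in the test diff below):

    from outlines.text.parsing import get_index, set_index

    index = set_index({}, (0, 1), "oo(")
    # The path 0 -> 1 now holds "oo(" at both levels:
    # {0: ({"oo("}, {1: ({"oo("}, {})})}

    assert get_index(index, (0,)) == {"oo("}
    assert get_index(index, (0, 1)) == {"oo("}
    # An unknown key deeper in the path falls back to the last values seen,
    # while an unknown first key returns None.
    assert get_index(index, (0, 5)) == {"oo("}
    assert get_index(index, (1,)) is None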
diff --git a/tests/text/test_parsing.py b/tests/text/test_parsing.py
index 9f3e34708..1dd516275 100644
--- a/tests/text/test_parsing.py
+++ b/tests/text/test_parsing.py
@@ -10,14 +10,16 @@
 from outlines.text.parsing import (
     PartialPythonIndenter,
     copy_parser_state,
-    create_pmatch_parser_states,
+    create_parser_vocab_index,
     find_partial_matches,
     fsm_union,
+    get_index,
+    get_lex_token_seqs,
     get_sub_fsms_from_seq,
     map_partial_states_to_vocab,
     parse_to_end,
+    set_index,
     terminals_to_fsms,
-    terminals_to_lalr_states,
 )
 
 
@@ -306,69 +308,6 @@ def test_map_partial_states_to_vocab_python():
     assert possible_paths["DEF"] == {0: {1}, 1: {2, 3}, 2: {3}}
 
 
-def test_parse_from_partial_match():
-    """Make sure we can continue parsing from an FSM-based partial match."""
-    lp = Lark(
-        r"""
-start: funcdef
-
-funcdef: "def" name "(" ")" ":" attr_pattern
-
-attr_pattern: NAME ("." NAME)+ -> value
-
-%ignore /[\t \f]+/ // WS
-
-!name: NAME | "match" | "case"
-NAME: /[^\W\d]\w*/
-
-
-        """,
-        parser="lalr",
-        postlex=PartialPythonIndenter(),
-    )
-
-    terminals_to_states = terminals_to_lalr_states(lp)
-    symbol_names_and_fsms = terminals_to_fsms(lp)
-
-    term_type = "DEF"
-    term_fsm = symbol_names_and_fsms[term_type]
-
-    # TODO FIXME: This is broken, and it's a bug in `lark`'s Python grammar?
-    # ptoken = "defx"
-
-    ptoken = "ef foo"
-    pmatches = find_partial_matches(term_fsm, ptoken)
-    first_pmatch = next(pm for pm in pmatches if pm[1][-1] in term_fsm.finals)
-    (parser_state,) = create_pmatch_parser_states(
-        lp, terminals_to_states, term_type, ptoken, first_pmatch
-    )
-    # These copies also patch the lexers in the parse state, which is now
-    # needed for use with `parse_to_end`
-    parser_state = copy_parser_state(parser_state)
-    new_parser_state, (terminal_name, fsm_state) = parse_to_end(parser_state)
-    assert terminal_name == "NAME"
-
-    ptoken = "ef foo():"
-    pmatches = find_partial_matches(term_fsm, ptoken)
-    first_pmatch = next(pm for pm in pmatches if pm[1][-1] in term_fsm.finals)
-    (parser_state,) = create_pmatch_parser_states(
-        lp, terminals_to_states, term_type, ptoken, first_pmatch
-    )
-    parser_state = copy_parser_state(parser_state)
-    new_parser_state, (terminal_name, fsm_state) = parse_to_end(parser_state)
-    assert terminal_name is None
-
-    ptoken = "ef ("
-    pmatches = find_partial_matches(term_fsm, ptoken)
-    first_pmatch = next(pm for pm in pmatches if pm[1][-1] in term_fsm.finals)
-    (parser_state,) = create_pmatch_parser_states(
-        lp, terminals_to_states, term_type, ptoken, first_pmatch
-    )
-    parser_state = copy_parser_state(parser_state)
-    with pytest.raises(UnexpectedToken):
-        parse_to_end(parser_state)
-
-
 def test_map_partial_states_to_vocab_regex():
     regex_string = r"([0-9]+([.][0-9]*)?|[.][0-9]+)"
     regex_pattern = interegular.parse_pattern(regex_string)
@@ -518,3 +457,289 @@ def test_get_sub_fsms_from_seq():
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
 
     assert res == [(3, False)]
+
+
+def test_get_lex_state_seqs():
+    lp = Lark(
+        r"""
+        start: funcdef | assign_stmt
+
+        funcdef: "def" name "(" ")" ":" "pass"
+
+        assign_stmt: name "=" name
+
+        %ignore /[\t \f]+/ // WS
+
+        !name: NAME | "match" | "case"
+        NAME: /[^\W\d]\w*/
+
+        """,
+        parser="lalr",
+        postlex=PartialPythonIndenter(),
+    )
+
+    assert lp.parse("a = b")
+    assert lp.parse("def foo(): pass")
+
+    results = get_lex_token_seqs(lp, "oo(")
+
+    assert results == {(0, ("NAME", "LPAR")), (1, ("NAME", "LPAR"))}
+
+    results = get_lex_token_seqs(lp, "def ")
+
+    assert results == {(0, ("DEF",)), (0, ("NAME",)), (1, ("NAME",))}
+
+    results = get_lex_token_seqs(lp, " def")
+
+    assert results == {(0, ("DEF",)), (0, ("NAME",)), (1, ("NAME",))}
+
+    # All invalid tokens
+    results = get_lex_token_seqs(lp, "%")
+    assert not results
+
+    # Good then invalid token
+    results = get_lex_token_seqs(lp, "a%")
+    assert not results
+
+    # TODO: All ignored tokens
+    # results = get_lex_token_seqs(lp, " \n")
+
+
+def test_index_operations():
+    tindex = set_index({}, (0, 1), "oo(")
+    assert tindex == {0: ({"oo("}, {1: ({"oo("}, {})})}
+
+    assert get_index(tindex, (1,)) is None
+    assert get_index(tindex, (0,)) == {"oo("}
+    assert get_index(tindex, (0, 1)) == {"oo("}
+
+    tindex = set_index(tindex, (0,), "a")
+    assert tindex == {0: ({"oo(", "a"}, {1: ({"oo("}, {})})}
+
+    assert get_index(tindex, (0, 1)) == {"oo("}
+
+    tindex = set_index(tindex, (1,), "b")
+    assert tindex == {0: ({"oo(", "a"}, {1: ({"oo("}, {})}), 1: ({"b"}, {})}
+
+    assert get_index(tindex, (0, 5)) == {"oo(", "a"}
+    assert get_index(tindex, (0, 1, 2)) == {"oo("}
+
+
+def test_parse_to_terminal():
+    from outlines.text.parsing import parse_to_terminal, terminal_reverse_maps
+
+    lp = Lark(
+        r"""
+        start: funcdef | assign
+
+        funcdef: "def" name "(" ")" ":" attr_pattern
+
+        assign: name "=" name
+
+        attr_pattern: NAME ("." NAME)+ -> value
+
+        %ignore /[\t \f]+/ // WS
+
+        !name: NAME | "match" | "case"
+        NAME: /[^\W\d]\w*/
+
+        """,
+        parser="lalr",
+        postlex=PartialPythonIndenter(),
+        debug=True,
+    )
+    lp, ordered_states = make_deterministic(lp)
+
+    def ordered_str(x):
+        return str(sorted(x, key=lambda x: str(x)))
+
+    str_names_to_states = {ordered_str(state): state for state in ordered_states}
+
+    symbols_to_states, reverse_shifts = terminal_reverse_maps(lp)
+
+    # We start with no states and let it tell us where we could've started
+    res = parse_to_terminal(lp, [], "NAME", symbols_to_states, reverse_shifts)
+
+    exp_res = {
+        (
+            str_names_to_states[
+                "[<$root_start : * start>, , , , , , , ]"
+            ],
+            str_names_to_states["[]"],
+        ),
+        (
+            str_names_to_states["[<__attr_pattern_plus_0 : DOT * NAME>]"],
+            str_names_to_states["[<__attr_pattern_plus_0 : DOT NAME * >]"],
+        ),
+        (
+            str_names_to_states[
+                "[<__attr_pattern_plus_0 : __attr_pattern_plus_0 DOT * NAME>]"
+            ],
+            str_names_to_states[
+                "[<__attr_pattern_plus_0 : __attr_pattern_plus_0 DOT NAME * >]"
+            ],
+        ),
+        (
+            str_names_to_states[
+                "[, , , ]"
+            ],
+            str_names_to_states["[]"],
+        ),
+        (
+            str_names_to_states[
+                "[, ]"
+            ],
+            str_names_to_states[
+                "[<__attr_pattern_plus_0 : * DOT NAME>, <__attr_pattern_plus_0 : * __attr_pattern_plus_0 DOT NAME>, ]"
+            ],
+        ),
+        (
+            str_names_to_states[
+                "[, , , ]"
+            ],
+            str_names_to_states["[]"],
+        ),
+    }
+    res = {tuple(r) for r in res}
+    assert res == exp_res
+
+    # A state right after reading a `NAME` token
+    stack = [str_names_to_states["[]"]]
+
+    # We start here:
+    # {}: {'LPAR': (Reduce,
+    #     Rule(NonTerminal(Token('RULE', 'name')), [Terminal('NAME')], None, RuleOptions(True, False, None, None))),
+    #     '$END': (Reduce,
+    #     Rule(NonTerminal(Token('RULE', 'name')), [Terminal('NAME')], None, RuleOptions(True, False, None, None))),
+    #     'EQUAL': (Reduce,
+    #     Rule(NonTerminal(Token('RULE', 'name')), [Terminal('NAME')], None, RuleOptions(True, False, None, None)))},
+    #
+    # Then pop that state and back up to this state:
+    #
+    # {, , , }: {'CASE': (Shift,
+    #     {}),
+    #     'name': (Shift, {}),
+    #     'NAME': (Shift, {}),
+    #     'MATCH': (Shift, {})},
+    #
+    # and move forward (via `'name'`) to:
+    #
+    # {}: {'LPAR': (Shift,
+    #     {})},
+    #
+    res = parse_to_terminal(lp, [stack], "LPAR", symbols_to_states, reverse_shifts)
+
+    exp_res = [
+        [
+            str_names_to_states[
+                "[]"
+            ],
+            str_names_to_states[
+                "[]"
+            ],
+        ]
+    ]
+
+    assert res == exp_res
+
+    # Parse this result forward and make sure it matches
+    ip = lp.parse_interactive("def foo(")
+    parser_state = copy_parser_state(ip.parser_state)
+    parser_state, (terminal_name, fsm_state) = parse_to_end(parser_state)
+    assert res[0] == parser_state.state_stack[-2:]
+
+
+@pytest.mark.xfail(reason="Not finished")
+def test_create_parser_vocab_index():
+    lp = Lark(
+        r"""
+        start: funcdef | assign
+
+        funcdef: "def" name "(" ")" ":" attr_pattern
+
+        assign: name "=" name
+
+        attr_pattern: NAME ("." NAME)+ -> value
+
+        %ignore /[\t \f]+/ // WS
+
+        !name: NAME | "match" | "case"
+        NAME: /[^\W\d]\w*/
+
+
+        """,
+        parser="lalr",
+        postlex=PartialPythonIndenter(),
+    )
+
+    vocabulary = [
+        "oo(",
+        "def ",
+        "f ",
+        "):",
+        "%",
+        # " "
+    ]
+
+    res = create_parser_vocab_index(lp, vocabulary)
+
+    # TODO: Implement some assertions
+    assert res
+    assert False
+
+
+@pytest.mark.skip(reason="Not finished, and currently very slow")
+def test_sequential_parser_index():
+    """Sequentially parse and make sure the index works."""
+    input_tokens = [
+        "x ",
+        "= ",
+        "1",
+        "\nde",
+        "f ",
+        "foo(",
+        "x)",
+        ":\n",
+        # " ",
+        " return x",
+        " + 1",
+        "\n",
+        "z ",
+        "= ",
+        "foo(",
+        '"hi',
+        '")\n',
+    ]
+    vocab = set(input_tokens)
+
+    lp = Lark.open_from_package(
+        "tests",
+        "partial_python.lark",
+        ["text"],
+        parser="lalr",
+        postlex=PartialPythonIndenter(),
+        start="file_input",
+    )
+    ip = lp.parse_interactive("")
+    parser_state = copy_parser_state(ip.parser_state)
+
+    vocab_index = create_parser_vocab_index(lp, vocab)
+
+    token_seq = ""
+    for i, token in enumerate(input_tokens):
+        token_seq += token
+
+        lex_state = parser_state.lexer.state
+        lex_state.text = token_seq
+
+        parser_state, (terminal_name, fsm_state) = parse_to_end(parser_state)
+
+        # TODO: Reverse the order of the index key sequences?
+        next_vocab = get_index(
+            vocab_index, [fsm_state] + list(reversed(parser_state.state_stack))
+        )
+
+        if i + 1 < len(input_tokens):
+            assert input_tokens[i + 1] in next_vocab
+        else:
+            assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])
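Taken together, the intended flow is: build the index once from a vocabulary,
then, after parsing a prefix, query it with the lexer FSM state and the
(reversed) parser state stack to get the vocabulary strings that can follow.
A minimal usage sketch, based on `test_parse_to_terminal` and the skipped
`test_sequential_parser_index` above; it assumes the `partial_python.lark`
grammar referenced by the tests and the helpers added in this diff, and the
key order is still marked as a TODO in the skipped test, so treat it as
illustrative rather than final:

    from lark import Lark

    from outlines.text.parsing import (
        PartialPythonIndenter,
        copy_parser_state,
        create_parser_vocab_index,
        get_index,
        parse_to_end,
    )

    vocab = {"x ", "= ", "1", "foo(", ")\n"}

    lp = Lark.open_from_package(
        "tests",
        "partial_python.lark",
        ["text"],
        parser="lalr",
        postlex=PartialPythonIndenter(),
        start="file_input",
    )

    # Build the trie once; keys are the lexer FSM start state followed by the
    # reversed parser state stack.
    vocab_index = create_parser_vocab_index(lp, vocab)

    # Parse a prefix, then ask which vocabulary strings could come next.
    parser_state = copy_parser_state(lp.parse_interactive("x ").parser_state)
    parser_state, (terminal_name, fsm_state) = parse_to_end(parser_state)

    next_vocab = get_index(
        vocab_index, [fsm_state] + list(reversed(parser_state.state_stack))
    )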