Skip to content

Commit

Permalink
Always return last read character position in find_partial_matches
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 26, 2023
1 parent 36d8b3d commit 5b97beb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
14 changes: 7 additions & 7 deletions outlines/text/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ def find_partial_matches(
Returns
-------
A set of tuples corresponding to each valid starting state in the FSM. The
first element of each tuple contains either ``None`` or an integer
indicating the position in `input_string` at which the FSM terminated. The
second element is the tuple of states visited during execution of the FSM
plus the next, unvisited transition state.
first element of each tuple contains an integer indicating the position in
`input_string` at which the FSM stopped. The second element is the tuple
of states visited during execution of the FSM plus the next, unvisited
transition state.
"""
if len(input_string) == 0 or input_string[0] not in fsm.alphabet:
Expand Down Expand Up @@ -357,17 +357,17 @@ def _partial_match(
if not terminated and state == fsm.initial:
return None, None

return None if not terminated else i, accepted_states
return i, accepted_states

res = set()
transition_maps = (
fsm.map if start_state is None else {start_state: fsm.map[start_state]}
)
for state, trans in transition_maps.items():
if trans_key in trans:
n_matched, path = _partial_match(trans)
last_match_idx, path = _partial_match(trans)
if path is not None:
res.add((n_matched, (state,) + path))
res.add((last_match_idx, (state,) + path))

return res

Expand Down
22 changes: 10 additions & 12 deletions tests/text/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_sequential_parse_example():
assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])


def test_partial_match():
def test_find_partial_matches():
name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
name_fsm = name_pattern.to_fsm().reduce()
assert name_fsm.initial == 0
Expand All @@ -201,12 +201,12 @@ def test_partial_match():
assert def_fsm.initial == 0

assert find_partial_matches(def_fsm, "def") == {(2, (0, 1, 2, 3))}
assert find_partial_matches(def_fsm, "de") == {(None, (0, 1, 2))}
assert find_partial_matches(def_fsm, "d") == {(None, (0, 1))}
assert find_partial_matches(def_fsm, "de") == {(1, (0, 1, 2))}
assert find_partial_matches(def_fsm, "d") == {(0, (0, 1))}
assert find_partial_matches(def_fsm, "") == set()
assert find_partial_matches(def_fsm, "df") == set()
assert find_partial_matches(def_fsm, "ef") == {(1, (1, 2, 3))}
assert find_partial_matches(def_fsm, "e") == {(None, (1, 2))}
assert find_partial_matches(def_fsm, "e") == {(0, (1, 2))}
assert find_partial_matches(def_fsm, "f") == {(0, (2, 3))}
assert find_partial_matches(def_fsm, "ef foo") == {(1, (1, 2, 3))}

Expand Down Expand Up @@ -235,11 +235,9 @@ def test_partial_match():
# from adequately reproducing the exact state sequences in this case.
# It seems to stem from `_CharGroup`s and the FSM map construction process.
res = find_partial_matches(float_fsm, ".")
assert {v[0] for v in res} == {0, 0, None}
# Make sure that the terminated sequences actually end in final states
assert all(v[1][-1] in float_fsm.finals for v in res if v[0] == 0)
# Make sure that the non-terminated sequences don't end in final states
assert all(v[1][-1] not in float_fsm.finals for v in res if v[0] != 0)
assert {v[0] for v in res} == {0, 0, 0}
assert sum(v[1][-1] in float_fsm.finals for v in res) == 2
assert sum(v[1][-1] not in float_fsm.finals for v in res) == 1


def test_map_partial_states_to_vocab_python():
Expand Down Expand Up @@ -340,7 +338,7 @@ def test_parse_from_partial_match():

ptoken = "ef foo"
pmatches = find_partial_matches(term_fsm, ptoken)
first_pmatch = next(pm for pm in pmatches if pm[0] is not None)
first_pmatch = next(pm for pm in pmatches if pm[1][-1] in term_fsm.finals)
(parser_state,) = create_pmatch_parser_states(
lp, terminals_to_states, term_type, ptoken, first_pmatch
)
Expand All @@ -352,7 +350,7 @@ def test_parse_from_partial_match():

ptoken = "ef foo():"
pmatches = find_partial_matches(term_fsm, ptoken)
first_pmatch = next(pm for pm in pmatches if pm[0] is not None)
first_pmatch = next(pm for pm in pmatches if pm[1][-1] in term_fsm.finals)
(parser_state,) = create_pmatch_parser_states(
lp, terminals_to_states, term_type, ptoken, first_pmatch
)
Expand All @@ -362,7 +360,7 @@ def test_parse_from_partial_match():

ptoken = "ef ("
pmatches = find_partial_matches(term_fsm, ptoken)
first_pmatch = next(pm for pm in pmatches if pm[0] is not None)
first_pmatch = next(pm for pm in pmatches if pm[1][-1] in term_fsm.finals)
(parser_state,) = create_pmatch_parser_states(
lp, terminals_to_states, term_type, ptoken, first_pmatch
)
Expand Down

0 comments on commit 5b97beb

Please sign in to comment.