diff --git a/tests/fsm/test_statistical.py b/tests/fsm/test_statistical.py
index acb1b54c..7c4af98a 100644
--- a/tests/fsm/test_statistical.py
+++ b/tests/fsm/test_statistical.py
@@ -1,13 +1,11 @@
 from typing import Callable, List, Optional
 
 import numpy as np
-import pytest
 from outlines_core.fsm import Guide, Index, Vocabulary
 from pytest import approx
 from scipy.stats import ks_2samp
 
 
-@pytest.mark.skip("Needs fixing")
 def test_generate_length():
     class NextToken:
         def __init__(
@@ -31,16 +29,18 @@ def __call__(
             return tokens + next_t if tokens is not None else next_t
 
     def generate(model, regex_str) -> Optional[List[int]]:
-        vocabulary = Vocabulary(3, {"0": [1], "1": [2], "2": [4]})
+        vocabulary = Vocabulary(3, {"0": [1], "1": [2]})
         index = Index(regex_str, vocabulary)
         guide = Guide(index)
 
-        n_tokens = len(vocabulary)
+        n_tokens = len(vocabulary) + 1  # include eos token in count
         tokens = None
         allowed = guide.get_start_tokens()
-        while not guide.is_finished():
+        while True:
             mask: List[int] = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)]
             tokens = model(tokens, mask=mask)
+            if tokens[-1] == 3:
+                break
             allowed = guide.read_next_token(tokens[-1])
         return tokens
 
@@ -70,9 +70,9 @@ def prob_markov(token: List[int]) -> np.array:
     lengths2: np.array = np.zeros((n_samples,))
     for i in range(n_samples):
         out1: List[int] = generate(model1, regex_str)
-        lengths1[i] = len(out1) - 1  # take off the eos token
+        lengths1[i] = len(out1) - 1
         out2: List[int] = generate(model2, regex_str)
-        lengths2[i] = len(out2) - 1  # take off the eos token
+        lengths2[i] = len(out2) - 1
 
     # 2 sample KS test to check that lengths has the same distribution as
     # L = 1 + 2*X + Y, where X ~ Bern(0.75) and Y ~ Neg-Binom(1, 0.3)