Skip to content

Commit

Permalink
Fix statistical test
Browse files Browse the repository at this point in the history
  • Loading branch information
dpsimpson authored and torymur committed Jan 16, 2025
1 parent a6a88da commit 6bedffa
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/fsm/test_statistical.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Callable, List, Optional

import numpy as np
import pytest
from outlines_core.fsm import Guide, Index, Vocabulary
from pytest import approx
from scipy.stats import ks_2samp


@pytest.mark.skip("Needs fixing")
def test_generate_length():
class NextToken:
def __init__(
Expand All @@ -31,16 +29,18 @@ def __call__(
return tokens + next_t if tokens is not None else next_t

def generate(model, regex_str) -> Optional[List[int]]:
vocabulary = Vocabulary(3, {"0": [1], "1": [2], "2": [4]})
vocabulary = Vocabulary(3, {"0": [1], "1": [2]})
index = Index(regex_str, vocabulary)
guide = Guide(index)

n_tokens = len(vocabulary)
n_tokens = len(vocabulary) + 1 # include eos token in count
tokens = None
allowed = guide.get_start_tokens()
while not guide.is_finished():
while True:
mask: List[int] = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)]
tokens = model(tokens, mask=mask)
if tokens[-1] == 3:
break
allowed = guide.read_next_token(tokens[-1])
return tokens

Expand Down Expand Up @@ -70,9 +70,9 @@ def prob_markov(token: List[int]) -> np.array:
lengths2: np.array = np.zeros((n_samples,))
for i in range(n_samples):
out1: List[int] = generate(model1, regex_str)
lengths1[i] = len(out1) - 1 # take off the eos token
lengths1[i] = len(out1) - 1
out2: List[int] = generate(model2, regex_str)
lengths2[i] = len(out2) - 1 # take off the eos token
lengths2[i] = len(out2) - 1

# 2 sample KS test to check that lengths has the same distribution as
# L = 1 + 2*X + Y, where X ~ Bern(0.75) and Y ~ Neg-Binom(1, 0.3)
Expand Down

0 comments on commit 6bedffa

Please sign in to comment.