-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtest.py
48 lines (35 loc) · 1014 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
Manual test script that runs minimal working examples (MWEs) of the
`surprisal` module and prints/plots the results.
"""
from matplotlib import pyplot as plt
import surprisal
def _plot_surprisals(surps) -> None:
    """Print every surprisal result and draw all of them on one shared figure."""
    fig, ax = plt.subplots()
    for surp in surps:
        print(surp)
        # optionally: surp.lineplot(fig, ax, cumulative=True)
        surp.lineplot(fig, ax)
    # Blocks until the plot window is closed.
    plt.show()


def _probe_word_spans(surp) -> None:
    """Print the surprisal recovered for a few word-index spans of one result."""
    print(f"tokens: {surp}")
    # Single word indices and a slice, all indexed at 'word' granularity.
    for wslc in [0, 1, slice(0, 1)]:
        print(f"span of interest (word index): {wslc}")
        print(f"recovered surprisal: {surp[wslc, 'word']}")
        print("=" * 32)


def main() -> None:
    """Score the stimuli with GPT-2 (with and without a BOS token) and inspect them.

    Run only when executed as a script, so that importing this module does not
    download a model or open a plot window.
    """
    g = surprisal.AutoHuggingFaceModel.from_pretrained(model_id="gpt2")
    # b = surprisal.AutoHuggingFaceModel.from_pretrained(model_id="bert-base-uncased")
    stims = [
        "I am a cat on the mat",
        # Alternative stimuli; uncomment to include them in the run.
        # "The cat sat on the mat.",
        # "The cat sat on the pizza.",
        # "How likely is a spicy donkey?",
        # "How likely is a spicy clock?",
        # "How likely is a spicy dish?",
        # "How likely is a spicy computer?",
        # "How likely is a spicy burrito?",
    ]
    # Results from the default call first, then from use_bos_token=False.
    surps = [*g.surprise(stims), *g.surprise(stims, use_bos_token=False)]

    _plot_surprisals(surps)

    # Inspect the last element of surps (from the use_bos_token=False call).
    *_, last = surps
    _probe_word_spans(last)


if __name__ == "__main__":
    main()