score.py
import os
import torch
import openai

from prompt import fill_in_query_discrimination_template, fill_in_brief_nli_discrimination_template, fill_in_nli_discrimination_template
def score_relevance(character, statement, query, discriminator):
    """Score how relevant a persona statement is to the user query.

    Returns a tensor of (not relevant, relevant) probabilities.
    """
    prompt = fill_in_query_discrimination_template(character, statement, query)
    if not isinstance(discriminator, str):
        # Local discriminator: a wrapper exposing a tokenizer (`tok`) and a
        # sequence classifier (`classifier`); softmax the logits into probabilities.
        scores = discriminator.classifier(
            **discriminator.tok(prompt, return_tensors="pt").to("cuda:0")
        ).logits[0].softmax(-1)
    else:
        # API discriminator: `discriminator` names an OpenAI chat model.
        system_prompt = "You are a helpful agent to build AI characters, your job is to determine whether an utterance from the human user to a role-playing AI should be responded by including the information in the given persona statement or not."
        verdict = openai.ChatCompletion.create(
            model=discriminator,
            temperature=0.0,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        ).choices[0]["message"]["content"]
        # Map the textual verdict onto a one-hot (no, yes) tensor.
        scores = torch.FloatTensor([verdict.lower() == label for label in ["no", "yes"]])
    return scores
def score_nli(character, statement, query, response, discriminator):
    """Score the NLI relation between a persona statement and the response.

    Returns a tensor of (contradict, neutral, entailed) probabilities.
    """
    if not isinstance(discriminator, str):
        prompt = fill_in_brief_nli_discrimination_template(character, statement, query, response)
        scores = discriminator.classifier(
            **discriminator.tok(prompt, return_tensors="pt").to("cuda:0")
        ).logits[0].softmax(-1)
    else:
        system_prompt = "You are a helpful agent to build AI characters, your job is to discriminate whether the given persona statement is entailed, neutral, contradict to the response in natural language inference."
        prompt = fill_in_nli_discrimination_template(character, statement, query, response)
        # Use a separate name for the model's verdict so the `response`
        # argument (the reply being judged) is not overwritten.
        verdict = openai.ChatCompletion.create(
            model=discriminator,
            temperature=0.0,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        ).choices[0]["message"]["content"]
        scores = torch.FloatTensor([verdict.lower() == label for label in ["contradict", "neutral", "entailed"]])
    return scores
def score_apc(character, statement, query, response, relevance_discriminator, nli_discriminator):
    relevance_score = score_relevance(character, statement, query, relevance_discriminator)
    nli_score = score_nli(character, statement, query, response, nli_discriminator)
    # An irrelevant statement should at least not be contradicted, and a relevant
    # one should be entailed: P(irrelevant)*(1 - P(contradict)) + P(relevant)*P(entailed).
    apc_score = relevance_score[0] * (1 - nli_score[0]) + relevance_score[1] * nli_score[2]
    return apc_score
def score_APC(character, statements, query, response, relevance_discriminator, nli_discriminator):
    # Sum the per-statement APC scores over the whole persona.
    return torch.stack([score_apc(character, statement, query, response, relevance_discriminator, nli_discriminator) for statement in statements]).sum()
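
# A minimal usage sketch, not from the original file: the inputs and the
# model names below are placeholders. Passing strings for both discriminators
# routes scoring through the OpenAI API, which expects OPENAI_API_KEY to be
# set in the environment; a non-string discriminator would instead need a
# wrapper object exposing the `tok`/`classifier` attributes used above.
if __name__ == "__main__":
    character = "Beethoven"
    statements = ["Beethoven was born in Bonn.", "Beethoven admired Mozart."]
    query = "Where were you born?"
    reply = "I was born in Bonn, my friend."
    apc = score_APC(character, statements, query, reply, "gpt-4", "gpt-4")
    print(f"APC score: {apc.item():.3f}")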