forked from facebookresearch/SimulEval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path first_agent.py
26 lines (20 loc) · 811 Bytes
/
first_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
from simuleval.utils import entrypoint
from simuleval.agents import TextToTextAgent
from simuleval.agents.actions import ReadAction, WriteAction
@entrypoint
class DummyWaitkTextAgent(TextToTextAgent):
    """Dummy text-to-text agent that follows a wait-k read/write schedule.

    The agent reads source tokens until the source leads the target by at
    least ``waitk`` tokens, then writes one token per step. After the source
    is finished, it keeps writing until the target length catches up with the
    source length. Each written token is a letter drawn at random from
    ``vocab``, so the output carries no translation content.
    """

    # Minimum lead (in tokens) the source must have over the target before
    # the agent writes while more source is still coming.
    waitk = 3
    # Output alphabet: the uppercase ASCII letters "A" through "Z".
    vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

    def policy(self):
        """Choose the next action from the current source/target lengths.

        Returns:
            ``ReadAction()`` while the source leads the target by fewer than
            ``waitk`` tokens and the source is not finished. Otherwise a
            ``WriteAction`` carrying one random letter from ``vocab``. That
            write is flagged ``finished`` only when the source is finished
            and this token brings the target up to the source length.
        """
        lagging = len(self.states.source) - len(self.states.target)

        # Not far enough behind and more source is still coming: keep reading.
        if lagging < self.waitk and not self.states.source_finished:
            return ReadAction()

        prediction = random.choice(self.vocab)
        # End the output only once the whole source has been read. Checking
        # lagging alone would flag a write as final mid-stream whenever
        # waitk <= 1, because such writes happen with lagging <= 1.
        finished = bool(self.states.source_finished) and lagging <= 1
        return WriteAction(prediction, finished=finished)