Skip to content

Commit 6257d0b

Browse files
committed
test: use mock tokenizer
1 parent 6771e1c commit 6257d0b

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

applications/Chat/tests/test_experience.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from coati.trainer.ppo import _set_default_generate_kwargs
1212
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
1313
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
14-
from transformers import PreTrainedTokenizer
1514
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
1615

1716
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -58,8 +57,13 @@ def make_and_consume_experience(strategy):
5857
actor, critic, initial_model, reward_model = \
5958
strategy.prepare(actor, critic, initial_model, reward_model)
6059

61-
tokenizer = PreTrainedTokenizer()
62-
tokenizer.padding_side = "left"
60+
class MockTokenizer():
61+
def __init__(self):
62+
self.padding_side = "left"
63+
self.eos_token_id = 0
64+
self.pad_token_id = 0
65+
66+
tokenizer = MockTokenizer()
6367
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
6468
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
6569

applications/Chat/tests/test_models.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
1414
from coati.models.opt import OPTRM, OPTActor, OPTCritic
1515
from coati.models.utils import calc_action_log_probs, masked_mean
16-
from transformers import PreTrainedTokenizer
1716

1817

1918
@pytest.mark.gpu
@@ -38,10 +37,16 @@ def test_generation(actor_maker: Callable[[], Actor],
3837
seq_len: int,
3938
generate_kwargs: Dict[str, Any]
4039
):
40+
41+
class MockTokenizer():
42+
def __init__(self):
43+
self.padding_side = "left"
44+
self.eos_token_id = 0
45+
self.pad_token_id = 0
46+
4147
actor = actor_maker()
4248
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
43-
tokenizer = PreTrainedTokenizer()
44-
tokenizer.padding_side = "left"
49+
tokenizer = MockTokenizer()
4550
sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
4651
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
4752

0 commit comments

Comments (0)