-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_gpt.py
53 lines (39 loc) · 1.56 KB
/
generate_gpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the ruGPT-3 model (GPT-2 architecture) and its tokenizer.
# Alternative, much larger checkpoint: 'ai-forever/ruGPT-3.5-13B'
model_name = 'ai-forever/rugpt3large_based_on_gpt2'
# Directory for Hugging Face downloads. The default is a Windows-specific
# path; set MODEL_CACHE_DIR to use a different location on other machines.
cache_dir = os.environ.get("MODEL_CACHE_DIR", "D:/models/huggingface")
print("Init model: " + model_name)
print("cache path: " + cache_dir)
model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=cache_dir)
tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
# Switch the model to evaluation (inference) mode, e.g. disabling dropout.
model.eval()
def generate_text(prompt, max_length=100, temperature=1.0, top_k=50, num_beams=4):
    """Generate a continuation of *prompt* with the module-level model.

    Args:
        prompt: Text to continue. It is encoded without special tokens.
        max_length: Maximum number of NEW tokens to generate. It is passed to
            ``generate`` as ``max_new_tokens``, so the prompt length is not
            counted.
        temperature: Sampling temperature.
        top_k: Sample only among the ``top_k`` most likely tokens at each step.
        num_beams: Number of beams. Together with ``do_sample=True`` this
            performs beam-search multinomial sampling.

    Returns:
        The decoded output (prompt plus continuation) with special tokens
        stripped.
    """
    encoded = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
    # Put the inputs on whatever device the model lives on, so moving the
    # model (e.g. to CUDA) needs no change here.
    input_ids = encoded['input_ids'].to(model.device)
    attention_mask = encoded['attention_mask'].to(model.device)
    # Fall back to EOS when the tokenizer defines no pad token; generate()
    # needs a pad id to pad finished beams and warns when none is given.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            pad_token_id=pad_token_id,
            temperature=temperature,
            max_new_tokens=max_length,
            num_beams=num_beams,
            top_k=top_k,
            do_sample=True,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def main():
    """Run an interactive prompt/generate loop.

    The loop ends when the user types 'exit' (case-insensitive, surrounding
    whitespace ignored), closes standard input, or presses Ctrl-C at the prompt.
    """
    # Report the model that is actually loaded instead of a hard-coded name.
    print(f"{model_name} text generation. Type 'exit' to quit.")
    while True:
        try:
            prompt = input("Enter your prompt: ")
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D / Ctrl-Z+Enter) or Ctrl-C: leave cleanly
            # instead of printing a traceback.
            print("\nExiting...")
            break
        if prompt.strip().lower() == 'exit':
            print("Exiting...")
            break
        if not prompt:
            # An empty prompt encodes to no tokens, so there is nothing to continue.
            continue
        generated_text = generate_text(prompt)
        print("\nGenerated Text:\n" + generated_text + "\n")


if __name__ == "__main__":
    main()