-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathprompt.py
78 lines (68 loc) · 2.37 KB
/
prompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from transformers import AutoModelForCausalLM, AutoTokenizer
from empower_functions.prompt import prompt_messages
import json
# Demo: multi-turn function calling with the empower-functions Llama-3 model.
# NOTE: the original script defined `device = "cuda"` but never used it —
# device placement is handled by `device_map="auto"` below and by
# `.to(model.device)` when tokenizing, so the dead variable is removed.

# --- Model setup -------------------------------------------------------------
model_path = "empower-dev/llama3-empower-functions-small"
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# --- Tool schema -------------------------------------------------------------
# JSON-Schema-style declaration of the single tool the model may call.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

# --- Conversation history ----------------------------------------------------
# One user request, the assistant's two parallel tool calls, and one `tool`
# result message per call. Each result is matched to its call via
# `tool_call_id`, and tool contents are JSON-encoded strings.
messages = [
    {"role": "user", "content": "Hi, can you tell me the current weather in San Francisco and New York City in Fahrenheit?"},
    {"role": "assistant", "tool_calls": [
        {
            "id": "get_current_weather_san_francisco",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": {
                    "location": "San Francisco, CA",
                    "unit": "fahrenheit"
                }
            }
        },
        {
            "id": "get_current_weather_new_york",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": {
                    "location": "New York City, NY",
                    "unit": "fahrenheit"
                }
            }
        }
    ]},
    {
        "role": "tool",
        "tool_call_id": "get_current_weather_san_francisco",
        "content": json.dumps({"temperature": 75})
    },
    {
        "role": "tool",
        "tool_call_id": "get_current_weather_new_york",
        "content": json.dumps({"temperature": 82})
    }
]

# --- Prompt construction and generation --------------------------------------
# Convert the OpenAI-style message list into the model's native prompt format.
# Enable CoT(Chain of Thought) by toggling include_thinking=True
messages = prompt_messages(messages, functions, include_thinking=False)

# Apply the tokenizer's chat template and move the input ids to the model's
# device (wherever device_map="auto" placed the weights).
model_inputs = tokenizer.apply_chat_template(
    messages, return_tensors="pt").to(model.device)

# Generate the assistant's next turn and print the decoded text.
generated_ids = model.generate(model_inputs, max_new_tokens=128)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])