Skip to content

Commit

Permalink
feat: Add optional tool_call_id to Message struct and update serializ…
Browse files Browse the repository at this point in the history
…ation tests

Signed-off-by: Eden Reich <eden.reich@gmail.com>
  • Loading branch information
edenreich committed Feb 9, 2025
1 parent 0fc63e0 commit d51e911
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ pub struct Message {
pub role: MessageRole,
/// Content of the message
pub content: String,
/// Unique identifier of the tool call
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}

/// Tool to use for generation
Expand Down Expand Up @@ -546,6 +549,40 @@ mod tests {
}
}

#[test]
fn test_message_serialization_with_tool_call_id() {
let message_with_tool = Message {
role: MessageRole::Tool,
content: "The weather is sunny".to_string(),
tool_call_id: Some("call_123".to_string()),
};

let serialized = serde_json::to_string(&message_with_tool).unwrap();
let expected_with_tool =
r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
assert_eq!(serialized, expected_with_tool);

let message_without_tool = Message {
role: MessageRole::User,
content: "What's the weather?".to_string(),
tool_call_id: None,
};

let serialized = serde_json::to_string(&message_without_tool).unwrap();
let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
assert_eq!(serialized, expected_without_tool);

let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
assert_eq!(deserialized.role, MessageRole::Tool);
assert_eq!(deserialized.content, "The weather is sunny");
assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
assert_eq!(deserialized.role, MessageRole::User);
assert_eq!(deserialized.content, "What's the weather?");
assert_eq!(deserialized.tool_call_id, None);
}

#[test]
fn test_provider_display() {
let providers = vec![
Expand All @@ -571,10 +608,12 @@ mod tests {
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
tool_call_id: None,
},
Message {
role: MessageRole::User,
content: "What is the current weather in Toronto?".to_string(),
tool_call_id: None,
},
],
stream: false,
Expand Down Expand Up @@ -644,7 +683,6 @@ mod tests {
async fn test_authentication_header() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

// Test with token
let mock_with_auth = server
.mock("GET", "/llms")
.match_header("authorization", "Bearer test-token")
Expand All @@ -658,7 +696,6 @@ mod tests {
client.list_models().await?;
mock_with_auth.assert();

// Test without token
let mock_without_auth = server
.mock("GET", "/llms")
.match_header("authorization", Matcher::Missing)
Expand Down Expand Up @@ -784,6 +821,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Hello".to_string(),
tool_call_id: None,
}];
let response = client
.generate_content(Provider::Ollama, "llama2", messages, None)
Expand Down Expand Up @@ -830,6 +868,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Hello".to_string(),
tool_call_id: None,
}];

let response = client
Expand Down Expand Up @@ -864,6 +903,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Hello".to_string(),
tool_call_id: None,
}];
let error = client
.generate_content(Provider::Groq, "mixtral-8x7b", messages, None)
Expand Down Expand Up @@ -952,6 +992,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Hello".to_string(),
tool_call_id: None,
}];

let response = client
Expand Down Expand Up @@ -989,6 +1030,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Test message".to_string(),
tool_call_id: None,
}];

let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
Expand Down Expand Up @@ -1043,6 +1085,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Test message".to_string(),
tool_call_id: None,
}];

let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
Expand Down Expand Up @@ -1114,6 +1157,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "What's the weather in London?".to_string(),
tool_call_id: None,
}];

let response = client
Expand Down Expand Up @@ -1168,6 +1212,7 @@ mod tests {
let messages = vec![Message {
role: MessageRole::User,
content: "Hi".to_string(),
tool_call_id: None,
}];

let response = client
Expand Down Expand Up @@ -1254,10 +1299,12 @@ mod tests {
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
tool_call_id: None,
},
Message {
role: MessageRole::User,
content: "What is the current weather in Toronto?".to_string(),
tool_call_id: None,
},
];

Expand Down

0 comments on commit d51e911

Please sign in to comment.