Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: add supported models in e2e #185

Merged
merged 5 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/forge_app/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ schemars = "0.8.21"
anyhow = "1.0.75"

[dev-dependencies]
futures = "0.3.31"
insta = "1.41.1"
pretty_assertions = "1.4.1"
tempfile = "3.10.1"
1 change: 0 additions & 1 deletion crates/forge_app/src/prompts/title.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ Example output structure:
8. Potential objections: [List any weaknesses for each title]
9. SEO alignment: [Reflect on SEO best practices for each title]
10. Selected title: [Explain your choice and its alignment]
11. Tool call preparation: generate_title(title: "Selected Title")

</title_generation_process>

Expand Down
162 changes: 98 additions & 64 deletions crates/forge_app/src/service/api.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::Result;
Expand Down Expand Up @@ -29,7 +30,7 @@ pub trait APIService: Send + Sync {

impl Service {
pub async fn api_service() -> Result<impl APIService> {
Live::new().await
Live::new(std::env::current_dir()?).await
}
}

Expand All @@ -45,8 +46,8 @@ struct Live {
}

impl Live {
async fn new() -> Result<Self> {
let env = Service::environment_service().get().await?;
async fn new(cwd: PathBuf) -> Result<Self> {
let env = Service::environment_service(cwd).get().await?;

let cwd: String = env.cwd.clone();
let provider = Arc::new(Service::provider_service(env.api_key.clone()));
Expand Down Expand Up @@ -145,81 +146,114 @@ impl APIService for Live {

#[cfg(test)]
mod tests {
use std::path::Path;

use forge_domain::ModelId;
use futures::future::join_all;
use tokio_stream::StreamExt;

use super::*;

#[tokio::test]
async fn test_e2e() {
let api = Live::new(Path::new("../../").to_path_buf()).await.unwrap();
let task = include_str!("./api_task.md");

const MAX_RETRIES: usize = 3;
const MATCH_THRESHOLD: f64 = 0.7; // 70% of crates must be found

let api = Live::new().await.unwrap();
let task = include_str!("./api_task.md");
let request = ChatRequest::new(ModelId::new("anthropic/claude-3.5-sonnet"), task);

let expected_crates = [
"forge_app",
"forge_ci",
"forge_domain",
"forge_main",
"forge_open_router",
"forge_prompt",
"forge_tool",
"forge_tool_macros",
"forge_walker",
const SUPPORTED_MODELS: &[&str] = &[
"anthropic/claude-3.5-sonnet:beta",
"openai/gpt-4o-2024-11-20",
"anthropic/claude-3.5-sonnet",
"openai/gpt-4o",
"openai/gpt-4o-mini",
"google/gemini-flash-1.5",
"anthropic/claude-3-sonnet",
];

let mut last_error = None;

for attempt in 0..MAX_RETRIES {
let response = api
.chat(request.clone())
.await
.unwrap()
.filter_map(|message| match message.unwrap() {
ChatResponse::Text(text) => Some(text),
_ => None,
})
.collect::<Vec<_>>()
.await
.join("")
.trim()
.to_string();

let found_crates: Vec<&str> = expected_crates
.iter()
.filter(|&crate_name| response.contains(&format!("<crate>{}</crate>", crate_name)))
.cloned()
.collect();

let match_percentage = found_crates.len() as f64 / expected_crates.len() as f64;

if match_percentage >= MATCH_THRESHOLD {
println!(
"Successfully found {:.2}% of expected crates",
match_percentage * 100.0
);
return;
// Build one future per entry in SUPPORTED_MODELS. Each future resolves to
// `Ok(())` once the model's reply names enough of the expected crates, or to
// `Err(message)` when the last of MAX_RETRIES attempts still falls short.
let test_futures = SUPPORTED_MODELS.iter().map(|&model| {
// Clone the API handle and copy the task text so the `async move` block
// below owns everything it uses.
let api = api.clone();
let task = task.to_string();

async move {
let request = ChatRequest::new(ModelId::new(model), task);
// Crate names the reply is checked against.
let expected_crates = [
"forge_app",
"forge_ci",
"forge_domain",
"forge_main",
"forge_open_router",
"forge_prompt",
"forge_tool",
"forge_tool_macros",
"forge_walker",
];

for attempt in 0..MAX_RETRIES {
// Send the request and concatenate every `ChatResponse::Text` item of
// the response; all other response variants are dropped. An error from
// `chat` or from any individual item panics here via `unwrap`.
let response = api
.chat(request.clone())
.await
.unwrap()
.filter_map(|message| match message.unwrap() {
ChatResponse::Text(text) => Some(text),
_ => None,
})
.collect::<Vec<_>>()
.await
.join("")
.trim()
.to_string();

// A crate counts as found only when the reply contains it wrapped as
// `<crate>name</crate>`.
let found_crates: Vec<&str> = expected_crates
.iter()
.filter(|&crate_name| {
response.contains(&format!("<crate>{}</crate>", crate_name))
})
.cloned()
.collect();

// Fraction of expected crates that were found, in the range 0.0..=1.0.
let match_percentage = found_crates.len() as f64 / expected_crates.len() as f64;

// Enough crates were found: this model passes.
if match_percentage >= MATCH_THRESHOLD {
println!(
"[{}] Successfully found {:.2}% of expected crates",
model,
match_percentage * 100.0
);
return Ok::<_, String>(());
}

// Not the final attempt: log progress and wait 2 seconds before retrying.
if attempt < MAX_RETRIES - 1 {
println!(
"[{}] Attempt {}/{}: Found {}/{} crates: {:?}",
model,
attempt + 1,
MAX_RETRIES,
found_crates.len(),
expected_crates.len(),
found_crates
);
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
} else {
// Final attempt also fell short: report the failure for this model.
return Err(format!(
"[{}] Failed: Found only {}/{} crates: {:?}",
model,
found_crates.len(),
expected_crates.len(),
found_crates
));
}
}

// The final loop iteration always returns (either `Ok` or `Err`), so this
// point is only reachable if MAX_RETRIES were zero.
unreachable!()
}
});

last_error = Some(format!(
"Attempt {}: Only found {}/{} crates: {:?}",
attempt + 1,
found_crates.len(),
expected_crates.len(),
found_crates
));
let results = join_all(test_futures).await;
let errors: Vec<_> = results.into_iter().filter_map(Result::err).collect();

// Add a small delay between retries to allow for different LLM generations
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
if !errors.is_empty() {
panic!("Test failures:\n{}", errors.join("\n"));
}

panic!(
"Failed after {} attempts. Last error: {}",
MAX_RETRIES,
last_error.unwrap_or_default()
);
}
}
27 changes: 18 additions & 9 deletions crates/forge_app/src/service/env.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::PathBuf;

use anyhow::Result;
use forge_domain::Environment;
use forge_walker::Walker;
Expand All @@ -11,27 +13,34 @@ pub trait EnvironmentService {
}

impl Service {
pub fn environment_service() -> impl EnvironmentService {
Live::new()
pub fn environment_service(base_dir: PathBuf) -> impl EnvironmentService {
Live::new(base_dir)
}
}

struct Live(Mutex<Option<Environment>>);
/// `EnvironmentService` implementation that builds the `Environment` on the
/// first `get` call and returns a clone of the cached value afterwards.
struct Live {
// Cached environment; `None` until the first successful `get`.
env: Mutex<Option<Environment>>,
// Directory passed to `from_env` as the working directory when the
// environment is built.
base_dir: PathBuf,
}

impl Live {
pub fn new() -> Self {
Self(Mutex::new(None))
pub fn new(base_dir: PathBuf) -> Self {
Self { env: Mutex::new(None), base_dir }
}

async fn from_env() -> Result<Environment> {
async fn from_env(cwd: Option<PathBuf>) -> Result<Environment> {
dotenv::dotenv().ok();
let api_key = std::env::var("FORGE_KEY").expect("FORGE_KEY must be set");
let large_model_id =
std::env::var("FORGE_LARGE_MODEL").unwrap_or("anthropic/claude-3.5-sonnet".to_owned());
let small_model_id =
std::env::var("FORGE_SMALL_MODEL").unwrap_or("anthropic/claude-3.5-haiku".to_owned());

let cwd = std::env::current_dir()?;
let cwd = if let Some(cwd) = cwd {
cwd
} else {
std::env::current_dir()?
};
let files = match Walker::new(cwd.clone())
.with_max_depth(usize::MAX)
.get()
Expand Down Expand Up @@ -65,12 +74,12 @@ impl Live {
#[async_trait::async_trait]
impl EnvironmentService for Live {
async fn get(&self) -> Result<Environment> {
let mut guard = self.0.lock().await;
let mut guard = self.env.lock().await;

if let Some(env) = guard.as_ref() {
return Ok(env.clone());
} else {
*guard = Some(Live::from_env().await?);
*guard = Some(Live::from_env(Some(self.base_dir.clone())).await?);
Ok(guard.as_ref().unwrap().clone())
}
}
Expand Down
Loading