Skip to content

Commit dfd7bb1

Browse files
tests: add supported models in e2e (#185)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent def4e3b commit dfd7bb1

File tree

5 files changed

+118
-74
lines changed

5 files changed

+118
-74
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/forge_app/Cargo.toml

+1
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ schemars = "0.8.21"
3535
anyhow = "1.0.75"
3636

3737
[dev-dependencies]
38+
futures = "0.3.31"
3839
insta = "1.41.1"
3940
pretty_assertions = "1.4.1"
4041
tempfile = "3.10.1"

crates/forge_app/src/prompts/title.md

-1
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@ Example output structure:
4848
8. Potential objections: [List any weaknesses for each title]
4949
9. SEO alignment: [Reflect on SEO best practices for each title]
5050
10. Selected title: [Explain your choice and its alignment]
51-
11. Tool call preparation: generate_title(title: "Selected Title")
5251

5352
</title_generation_process>
5453

crates/forge_app/src/service/api.rs

+98 -64
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
use std::path::PathBuf;
12
use std::sync::Arc;
23

34
use anyhow::Result;
@@ -29,7 +30,7 @@ pub trait APIService: Send + Sync {
2930

3031
impl Service {
3132
pub async fn api_service() -> Result<impl APIService> {
32-
Live::new().await
33+
Live::new(std::env::current_dir()?).await
3334
}
3435
}
3536

@@ -45,8 +46,8 @@ struct Live {
4546
}
4647

4748
impl Live {
48-
async fn new() -> Result<Self> {
49-
let env = Service::environment_service().get().await?;
49+
async fn new(cwd: PathBuf) -> Result<Self> {
50+
let env = Service::environment_service(cwd).get().await?;
5051

5152
let cwd: String = env.cwd.clone();
5253
let provider = Arc::new(Service::provider_service(env.api_key.clone()));
@@ -146,81 +147,114 @@ impl APIService for Live {
146147

147148
#[cfg(test)]
148149
mod tests {
150+
use std::path::Path;
151+
149152
use forge_domain::ModelId;
153+
use futures::future::join_all;
150154
use tokio_stream::StreamExt;
151155

152156
use super::*;
153157

154158
#[tokio::test]
155159
async fn test_e2e() {
160+
let api = Live::new(Path::new("../../").to_path_buf()).await.unwrap();
161+
let task = include_str!("./api_task.md");
162+
156163
const MAX_RETRIES: usize = 3;
157164
const MATCH_THRESHOLD: f64 = 0.7; // 70% of crates must be found
158-
159-
let api = Live::new().await.unwrap();
160-
let task = include_str!("./api_task.md");
161-
let request = ChatRequest::new(ModelId::new("anthropic/claude-3.5-sonnet"), task);
162-
163-
let expected_crates = [
164-
"forge_app",
165-
"forge_ci",
166-
"forge_domain",
167-
"forge_main",
168-
"forge_open_router",
169-
"forge_prompt",
170-
"forge_tool",
171-
"forge_tool_macros",
172-
"forge_walker",
165+
const SUPPORTED_MODELS: &[&str] = &[
166+
"anthropic/claude-3.5-sonnet:beta",
167+
"openai/gpt-4o-2024-11-20",
168+
"anthropic/claude-3.5-sonnet",
169+
"openai/gpt-4o",
170+
"openai/gpt-4o-mini",
171+
"google/gemini-flash-1.5",
172+
"anthropic/claude-3-sonnet",
173173
];
174174

175-
let mut last_error = None;
176-
177-
for attempt in 0..MAX_RETRIES {
178-
let response = api
179-
.chat(request.clone())
180-
.await
181-
.unwrap()
182-
.filter_map(|message| match message.unwrap() {
183-
ChatResponse::Text(text) => Some(text),
184-
_ => None,
185-
})
186-
.collect::<Vec<_>>()
187-
.await
188-
.join("")
189-
.trim()
190-
.to_string();
191-
192-
let found_crates: Vec<&str> = expected_crates
193-
.iter()
194-
.filter(|&crate_name| response.contains(&format!("<crate>{}</crate>", crate_name)))
195-
.cloned()
196-
.collect();
197-
198-
let match_percentage = found_crates.len() as f64 / expected_crates.len() as f64;
199-
200-
if match_percentage >= MATCH_THRESHOLD {
201-
println!(
202-
"Successfully found {:.2}% of expected crates",
203-
match_percentage * 100.0
204-
);
205-
return;
175+
let test_futures = SUPPORTED_MODELS.iter().map(|&model| {
176+
let api = api.clone();
177+
let task = task.to_string();
178+
179+
async move {
180+
let request = ChatRequest::new(ModelId::new(model), task);
181+
let expected_crates = [
182+
"forge_app",
183+
"forge_ci",
184+
"forge_domain",
185+
"forge_main",
186+
"forge_open_router",
187+
"forge_prompt",
188+
"forge_tool",
189+
"forge_tool_macros",
190+
"forge_walker",
191+
];
192+
193+
for attempt in 0..MAX_RETRIES {
194+
let response = api
195+
.chat(request.clone())
196+
.await
197+
.unwrap()
198+
.filter_map(|message| match message.unwrap() {
199+
ChatResponse::Text(text) => Some(text),
200+
_ => None,
201+
})
202+
.collect::<Vec<_>>()
203+
.await
204+
.join("")
205+
.trim()
206+
.to_string();
207+
208+
let found_crates: Vec<&str> = expected_crates
209+
.iter()
210+
.filter(|&crate_name| {
211+
response.contains(&format!("<crate>{}</crate>", crate_name))
212+
})
213+
.cloned()
214+
.collect();
215+
216+
let match_percentage = found_crates.len() as f64 / expected_crates.len() as f64;
217+
218+
if match_percentage >= MATCH_THRESHOLD {
219+
println!(
220+
"[{}] Successfully found {:.2}% of expected crates",
221+
model,
222+
match_percentage * 100.0
223+
);
224+
return Ok::<_, String>(());
225+
}
226+
227+
if attempt < MAX_RETRIES - 1 {
228+
println!(
229+
"[{}] Attempt {}/{}: Found {}/{} crates: {:?}",
230+
model,
231+
attempt + 1,
232+
MAX_RETRIES,
233+
found_crates.len(),
234+
expected_crates.len(),
235+
found_crates
236+
);
237+
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
238+
} else {
239+
return Err(format!(
240+
"[{}] Failed: Found only {}/{} crates: {:?}",
241+
model,
242+
found_crates.len(),
243+
expected_crates.len(),
244+
found_crates
245+
));
246+
}
247+
}
248+
249+
unreachable!()
206250
}
251+
});
207252

208-
last_error = Some(format!(
209-
"Attempt {}: Only found {}/{} crates: {:?}",
210-
attempt + 1,
211-
found_crates.len(),
212-
expected_crates.len(),
213-
found_crates
214-
));
253+
let results = join_all(test_futures).await;
254+
let errors: Vec<_> = results.into_iter().filter_map(Result::err).collect();
215255

216-
// Add a small delay between retries to allow for different LLM generations
217-
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
256+
if !errors.is_empty() {
257+
panic!("Test failures:\n{}", errors.join("\n"));
218258
}
219-
220-
panic!(
221-
"Failed after {} attempts. Last error: {}",
222-
MAX_RETRIES,
223-
last_error.unwrap_or_default()
224-
);
225259
}
226260
}

crates/forge_app/src/service/env.rs

+18 -9
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
use std::path::PathBuf;
2+
13
use anyhow::Result;
24
use forge_domain::Environment;
35
use forge_walker::Walker;
@@ -11,27 +13,34 @@ pub trait EnvironmentService {
1113
}
1214

1315
impl Service {
14-
pub fn environment_service() -> impl EnvironmentService {
15-
Live::new()
16+
pub fn environment_service(base_dir: PathBuf) -> impl EnvironmentService {
17+
Live::new(base_dir)
1618
}
1719
}
1820

19-
struct Live(Mutex<Option<Environment>>);
21+
struct Live {
22+
env: Mutex<Option<Environment>>,
23+
base_dir: PathBuf,
24+
}
2025

2126
impl Live {
22-
pub fn new() -> Self {
23-
Self(Mutex::new(None))
27+
pub fn new(base_dir: PathBuf) -> Self {
28+
Self { env: Mutex::new(None), base_dir }
2429
}
2530

26-
async fn from_env() -> Result<Environment> {
31+
async fn from_env(cwd: Option<PathBuf>) -> Result<Environment> {
2732
dotenv::dotenv().ok();
2833
let api_key = std::env::var("FORGE_KEY").expect("FORGE_KEY must be set");
2934
let large_model_id =
3035
std::env::var("FORGE_LARGE_MODEL").unwrap_or("anthropic/claude-3.5-sonnet".to_owned());
3136
let small_model_id =
3237
std::env::var("FORGE_SMALL_MODEL").unwrap_or("anthropic/claude-3.5-haiku".to_owned());
3338

34-
let cwd = std::env::current_dir()?;
39+
let cwd = if let Some(cwd) = cwd {
40+
cwd
41+
} else {
42+
std::env::current_dir()?
43+
};
3544
let files = match Walker::new(cwd.clone())
3645
.with_max_depth(usize::MAX)
3746
.get()
@@ -65,12 +74,12 @@ impl Live {
6574
#[async_trait::async_trait]
6675
impl EnvironmentService for Live {
6776
async fn get(&self) -> Result<Environment> {
68-
let mut guard = self.0.lock().await;
77+
let mut guard = self.env.lock().await;
6978

7079
if let Some(env) = guard.as_ref() {
7180
return Ok(env.clone());
7281
} else {
73-
*guard = Some(Live::from_env().await?);
82+
*guard = Some(Live::from_env(Some(self.base_dir.clone())).await?);
7483
Ok(guard.as_ref().unwrap().clone())
7584
}
7685
}

0 commit comments

Comments
 (0)