1
+ use std:: path:: PathBuf ;
1
2
use std:: sync:: Arc ;
2
3
3
4
use anyhow:: Result ;
@@ -29,7 +30,7 @@ pub trait APIService: Send + Sync {
29
30
30
31
impl Service {
31
32
pub async fn api_service ( ) -> Result < impl APIService > {
32
- Live :: new ( ) . await
33
+ Live :: new ( std :: env :: current_dir ( ) ? ) . await
33
34
}
34
35
}
35
36
@@ -45,8 +46,8 @@ struct Live {
45
46
}
46
47
47
48
impl Live {
48
- async fn new ( ) -> Result < Self > {
49
- let env = Service :: environment_service ( ) . get ( ) . await ?;
49
+ async fn new ( cwd : PathBuf ) -> Result < Self > {
50
+ let env = Service :: environment_service ( cwd ) . get ( ) . await ?;
50
51
51
52
let cwd: String = env. cwd . clone ( ) ;
52
53
let provider = Arc :: new ( Service :: provider_service ( env. api_key . clone ( ) ) ) ;
@@ -146,81 +147,114 @@ impl APIService for Live {
146
147
147
148
#[ cfg( test) ]
148
149
mod tests {
150
+ use std:: path:: Path ;
151
+
149
152
use forge_domain:: ModelId ;
153
+ use futures:: future:: join_all;
150
154
use tokio_stream:: StreamExt ;
151
155
152
156
use super :: * ;
153
157
154
158
#[ tokio:: test]
155
159
async fn test_e2e ( ) {
160
+ let api = Live :: new ( Path :: new ( "../../" ) . to_path_buf ( ) ) . await . unwrap ( ) ;
161
+ let task = include_str ! ( "./api_task.md" ) ;
162
+
156
163
const MAX_RETRIES : usize = 3 ;
157
164
const MATCH_THRESHOLD : f64 = 0.7 ; // 70% of crates must be found
158
-
159
- let api = Live :: new ( ) . await . unwrap ( ) ;
160
- let task = include_str ! ( "./api_task.md" ) ;
161
- let request = ChatRequest :: new ( ModelId :: new ( "anthropic/claude-3.5-sonnet" ) , task) ;
162
-
163
- let expected_crates = [
164
- "forge_app" ,
165
- "forge_ci" ,
166
- "forge_domain" ,
167
- "forge_main" ,
168
- "forge_open_router" ,
169
- "forge_prompt" ,
170
- "forge_tool" ,
171
- "forge_tool_macros" ,
172
- "forge_walker" ,
165
+ const SUPPORTED_MODELS : & [ & str ] = & [
166
+ "anthropic/claude-3.5-sonnet:beta" ,
167
+ "openai/gpt-4o-2024-11-20" ,
168
+ "anthropic/claude-3.5-sonnet" ,
169
+ "openai/gpt-4o" ,
170
+ "openai/gpt-4o-mini" ,
171
+ "google/gemini-flash-1.5" ,
172
+ "anthropic/claude-3-sonnet" ,
173
173
] ;
174
174
175
- let mut last_error = None ;
176
-
177
- for attempt in 0 ..MAX_RETRIES {
178
- let response = api
179
- . chat ( request. clone ( ) )
180
- . await
181
- . unwrap ( )
182
- . filter_map ( |message| match message. unwrap ( ) {
183
- ChatResponse :: Text ( text) => Some ( text) ,
184
- _ => None ,
185
- } )
186
- . collect :: < Vec < _ > > ( )
187
- . await
188
- . join ( "" )
189
- . trim ( )
190
- . to_string ( ) ;
191
-
192
- let found_crates: Vec < & str > = expected_crates
193
- . iter ( )
194
- . filter ( |& crate_name| response. contains ( & format ! ( "<crate>{}</crate>" , crate_name) ) )
195
- . cloned ( )
196
- . collect ( ) ;
197
-
198
- let match_percentage = found_crates. len ( ) as f64 / expected_crates. len ( ) as f64 ;
199
-
200
- if match_percentage >= MATCH_THRESHOLD {
201
- println ! (
202
- "Successfully found {:.2}% of expected crates" ,
203
- match_percentage * 100.0
204
- ) ;
205
- return ;
175
+ let test_futures = SUPPORTED_MODELS . iter ( ) . map ( |& model| {
176
+ let api = api. clone ( ) ;
177
+ let task = task. to_string ( ) ;
178
+
179
+ async move {
180
+ let request = ChatRequest :: new ( ModelId :: new ( model) , task) ;
181
+ let expected_crates = [
182
+ "forge_app" ,
183
+ "forge_ci" ,
184
+ "forge_domain" ,
185
+ "forge_main" ,
186
+ "forge_open_router" ,
187
+ "forge_prompt" ,
188
+ "forge_tool" ,
189
+ "forge_tool_macros" ,
190
+ "forge_walker" ,
191
+ ] ;
192
+
193
+ for attempt in 0 ..MAX_RETRIES {
194
+ let response = api
195
+ . chat ( request. clone ( ) )
196
+ . await
197
+ . unwrap ( )
198
+ . filter_map ( |message| match message. unwrap ( ) {
199
+ ChatResponse :: Text ( text) => Some ( text) ,
200
+ _ => None ,
201
+ } )
202
+ . collect :: < Vec < _ > > ( )
203
+ . await
204
+ . join ( "" )
205
+ . trim ( )
206
+ . to_string ( ) ;
207
+
208
+ let found_crates: Vec < & str > = expected_crates
209
+ . iter ( )
210
+ . filter ( |& crate_name| {
211
+ response. contains ( & format ! ( "<crate>{}</crate>" , crate_name) )
212
+ } )
213
+ . cloned ( )
214
+ . collect ( ) ;
215
+
216
+ let match_percentage = found_crates. len ( ) as f64 / expected_crates. len ( ) as f64 ;
217
+
218
+ if match_percentage >= MATCH_THRESHOLD {
219
+ println ! (
220
+ "[{}] Successfully found {:.2}% of expected crates" ,
221
+ model,
222
+ match_percentage * 100.0
223
+ ) ;
224
+ return Ok :: < _ , String > ( ( ) ) ;
225
+ }
226
+
227
+ if attempt < MAX_RETRIES - 1 {
228
+ println ! (
229
+ "[{}] Attempt {}/{}: Found {}/{} crates: {:?}" ,
230
+ model,
231
+ attempt + 1 ,
232
+ MAX_RETRIES ,
233
+ found_crates. len( ) ,
234
+ expected_crates. len( ) ,
235
+ found_crates
236
+ ) ;
237
+ tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 2 ) ) . await ;
238
+ } else {
239
+ return Err ( format ! (
240
+ "[{}] Failed: Found only {}/{} crates: {:?}" ,
241
+ model,
242
+ found_crates. len( ) ,
243
+ expected_crates. len( ) ,
244
+ found_crates
245
+ ) ) ;
246
+ }
247
+ }
248
+
249
+ unreachable ! ( )
206
250
}
251
+ } ) ;
207
252
208
- last_error = Some ( format ! (
209
- "Attempt {}: Only found {}/{} crates: {:?}" ,
210
- attempt + 1 ,
211
- found_crates. len( ) ,
212
- expected_crates. len( ) ,
213
- found_crates
214
- ) ) ;
253
+ let results = join_all ( test_futures) . await ;
254
+ let errors: Vec < _ > = results. into_iter ( ) . filter_map ( Result :: err) . collect ( ) ;
215
255
216
- // Add a small delay between retries to allow for different LLM generations
217
- tokio :: time :: sleep ( std :: time :: Duration :: from_secs ( 2 ) ) . await ;
256
+ if !errors . is_empty ( ) {
257
+ panic ! ( "Test failures: \n {}" , errors . join ( " \n " ) ) ;
218
258
}
219
-
220
- panic ! (
221
- "Failed after {} attempts. Last error: {}" ,
222
- MAX_RETRIES ,
223
- last_error. unwrap_or_default( )
224
- ) ;
225
259
}
226
260
}
0 commit comments