diff --git a/app-server/src/api/v1/mod.rs b/app-server/src/api/v1/mod.rs index e8241fbf..24145de9 100644 --- a/app-server/src/api/v1/mod.rs +++ b/app-server/src/api/v1/mod.rs @@ -3,5 +3,6 @@ pub mod evaluations; pub mod machine_manager; pub mod metrics; pub mod pipelines; +pub mod queues; pub mod semantic_search; pub mod traces; diff --git a/app-server/src/api/v1/queues.rs b/app-server/src/api/v1/queues.rs new file mode 100644 index 00000000..903f5c02 --- /dev/null +++ b/app-server/src/api/v1/queues.rs @@ -0,0 +1,134 @@ +use std::collections::HashMap; + +use actix_web::{post, web, HttpResponse}; +use chrono::Utc; +use serde::Deserialize; +use serde_json::Value; +use uuid::Uuid; + +use crate::{ + cache::Cache, + db::{ + project_api_keys::ProjectApiKey, + spans::{Span, SpanType}, + trace::TraceType, + DB, + }, + evaluations::utils::LabelingQueueEntry, + features::{is_feature_enabled, Feature}, + routes::types::ResponseResult, + traces::span_attributes::ASSOCIATION_PROPERTIES_PREFIX, +}; + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct PushItem { + #[serde(default = "Uuid::new_v4")] + id: Uuid, + #[serde(default)] + name: String, + #[serde(default)] + attributes: HashMap, + #[serde(default)] + input: Option, + #[serde(default)] + output: Option, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct PushToQueueRequest { + items: Vec, + queue_name: String, +} + +#[post("/queues/push")] +async fn push_to_queue( + project_api_key: ProjectApiKey, + req: web::Json, + db: web::Data, + cache: web::Data, +) -> ResponseResult { + let db = db.into_inner(); + let req = req.into_inner(); + let project_id = project_api_key.project_id; + let queue_name = req.queue_name; + let request_items = req.items; + let cache = cache.into_inner(); + + let Some(queue) = + crate::db::labeling_queues::get_labeling_queue_by_name(&db.pool, &queue_name, &project_id) + .await? + else { + return Ok(HttpResponse::NotFound().body(format!("Queue not found: {}", queue_name))); + }; + + let mut span_ids = Vec::with_capacity(request_items.len()); + let num_spans = request_items.len(); + crate::db::stats::add_spans_and_events_to_project_usage_stats( + &db.pool, + &project_id, + num_spans as i64, + 0, + ) + .await?; + if is_feature_enabled(Feature::UsageLimit) { + if let Ok(limits_exceeded) = + crate::traces::limits::update_workspace_limit_exceeded_by_project_id( + db.clone(), + cache.clone(), + project_id, + ) + .await + { + if limits_exceeded.spans { + return Ok(HttpResponse::TooManyRequests().body("Workspace span limit exceeded")); + } + } + } + + for request_item in request_items { + let mut attributes = request_item.attributes; + attributes.insert( + format!("{ASSOCIATION_PROPERTIES_PREFIX}.trace_type"), + // Temporary, in order not to show spans in the default trace view + serde_json::to_value(TraceType::EVENT).unwrap(), + ); + let mut span = Span { + span_id: request_item.id, + trace_id: Uuid::new_v4(), + parent_span_id: None, + name: request_item.name, + start_time: Utc::now(), + end_time: Utc::now(), + attributes: serde_json::to_value(attributes).unwrap(), + span_type: SpanType::DEFAULT, + input: request_item.input, + output: request_item.output, + events: None, + labels: None, + }; + + let span_usage = crate::traces::utils::get_llm_usage_for_span( + &mut span.get_attributes(), + db.clone(), + cache.clone(), + ) + .await; + + crate::traces::utils::record_span_to_db(db.clone(), &span_usage, &project_id, &mut span) + .await?; + span_ids.push(span.span_id); + } + + let queue_entries = span_ids + .iter() + .map(|span_id| LabelingQueueEntry { + span_id: span_id.clone(), + action: Value::Null, + }) + .collect::>(); + crate::db::labeling_queues::push_to_labeling_queue(&db.pool, &queue.id, &queue_entries).await?; + + Ok(HttpResponse::Ok().body("Items uploaded successfully")) +} diff --git a/app-server/src/main.rs b/app-server/src/main.rs index 20ae9c92..97ae0d8d 100644 --- a/app-server/src/main.rs +++ b/app-server/src/main.rs @@ -428,10 +428,10 @@ fn main() -> anyhow::Result<()> { .service(api::v1::evaluations::create_evaluation) .service(api::v1::metrics::process_metrics) .service(api::v1::semantic_search::semantic_search) + .service(api::v1::queues::push_to_queue) .service(api::v1::machine_manager::start_machine) .service(api::v1::machine_manager::terminate_machine) - .service(api::v1::machine_manager::execute_computer_action) - .app_data(PayloadConfig::new(10 * 1024 * 1024)), + .service(api::v1::machine_manager::execute_computer_action), ) // Scopes with generic auth .service( diff --git a/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts b/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts index 87ae4726..264b8d80 100644 --- a/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts +++ b/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts @@ -31,10 +31,10 @@ const removeQueueItemSchema = z.object({ }), reasoning: z.string().optional().nullable() })), - action: z.object({ + action: z.null().or(z.object({ resultId: z.string().optional(), datasetId: z.string().optional() - }) + })) }); // remove an item from the queue @@ -61,16 +61,18 @@ export async function POST(request: Request, { params }: { params: { projectId: labelSource: "MANUAL" as const, })); - const insertedLabels = await db.insert(labels).values(newLabels).onConflictDoUpdate({ - target: [labels.spanId, labels.classId, labels.userId], - set: { - value: sql`excluded.value`, - labelSource: sql`excluded.label_source`, - reasoning: sql`COALESCE(excluded.reasoning, labels.reasoning)`, - } - }).returning(); + const insertedLabels = newLabels.length > 0 + ? await db.insert(labels).values(newLabels).onConflictDoUpdate({ + target: [labels.spanId, labels.classId, labels.userId], + set: { + value: sql`excluded.value`, + labelSource: sql`excluded.label_source`, + reasoning: sql`COALESCE(excluded.reasoning, labels.reasoning)`, + } + }).returning() + : []; - if (action.resultId) { + if (action?.resultId) { const resultId = action.resultId; const userName = user.name ? ` (${user.name})` : ''; @@ -119,8 +121,7 @@ export async function POST(request: Request, { params }: { params: { projectId: } } - if (action.datasetId) { - + if (action?.datasetId) { const span = await db.query.spans.findFirst({ where: and(eq(spans.spanId, spanId), eq(spans.projectId, params.projectId)) }); @@ -135,7 +136,7 @@ export async function POST(request: Request, { params }: { params: { projectId: metadata: { spanId: span.spanId, }, - datasetId: action.datasetId, + datasetId: action?.datasetId, }).returning(); await db.insert(datapointToSpan).values({ diff --git a/frontend/app/api/projects/[projectId]/queues/route.ts b/frontend/app/api/projects/[projectId]/queues/route.ts index 3324a92a..82f8139b 100644 --- a/frontend/app/api/projects/[projectId]/queues/route.ts +++ b/frontend/app/api/projects/[projectId]/queues/route.ts @@ -11,8 +11,6 @@ export async function POST( ): Promise { const projectId = params.projectId; - - const body = await req.json(); const { name } = body;