diff --git a/query-engine/core/src/compiler/mod.rs b/query-engine/core/src/compiler/mod.rs index 26170861f25..4fbad2b50d6 100644 --- a/query-engine/core/src/compiler/mod.rs +++ b/query-engine/core/src/compiler/mod.rs @@ -6,6 +6,7 @@ use std::sync::Arc; pub use expression::Expression; use schema::QuerySchema; use thiserror::Error; +use quaint::connector::ConnectionInfo; pub use translate::{translate, TranslateError}; use crate::{QueryDocument, QueryGraphBuilder}; @@ -19,11 +20,11 @@ pub enum CompileError { TranslateError(#[from] TranslateError), } -pub fn compile(query_schema: &Arc, query_doc: QueryDocument) -> crate::Result { +pub fn compile(query_schema: &Arc, query_doc: QueryDocument, connection_info: &ConnectionInfo) -> crate::Result { let QueryDocument::Single(query) = query_doc else { return Err(CompileError::UnsupportedRequest.into()); }; let (graph, _serializer) = QueryGraphBuilder::new(query_schema).build(query)?; - Ok(translate(graph).map_err(CompileError::from)?) + Ok(translate(graph, connection_info).map_err(CompileError::from)?) } diff --git a/query-engine/core/src/compiler/translate.rs b/query-engine/core/src/compiler/translate.rs index 650d03e936f..fa46356c017 100644 --- a/query-engine/core/src/compiler/translate.rs +++ b/query-engine/core/src/compiler/translate.rs @@ -2,7 +2,7 @@ mod query; use query::translate_query; use thiserror::Error; - +use quaint::connector::ConnectionInfo; use crate::{EdgeRef, Node, NodeRef, Query, QueryGraph}; use super::expression::{Binding, Expression}; @@ -18,28 +18,30 @@ pub enum TranslateError { pub type TranslateResult = Result; -pub fn translate(mut graph: QueryGraph) -> TranslateResult { +pub fn translate(mut graph: QueryGraph, connection_info: &ConnectionInfo) -> TranslateResult { graph .root_nodes() .into_iter() - .map(|node| NodeTranslator::new(&mut graph, node, &[]).translate()) + .map(|node| NodeTranslator::new(&mut graph, node, &[], connection_info).translate()) .collect::>>() .map(Expression::Seq) } -struct NodeTranslator<'a, 'b> { +struct NodeTranslator<'a, 'b, 'c> { graph: &'a mut QueryGraph, node: NodeRef, #[allow(dead_code)] parent_edges: &'b [EdgeRef], + connection_info: &'c ConnectionInfo, } -impl<'a, 'b> NodeTranslator<'a, 'b> { - fn new(graph: &'a mut QueryGraph, node: NodeRef, parent_edges: &'b [EdgeRef]) -> Self { +impl<'a, 'b, 'c> NodeTranslator<'a, 'b, 'c> { + fn new(graph: &'a mut QueryGraph, node: NodeRef, parent_edges: &'b [EdgeRef], connection_info: &'c ConnectionInfo) -> Self { Self { graph, node, parent_edges, + connection_info, } } @@ -64,7 +66,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> { .try_into() .expect("current node must be query"); - translate_query(query) + translate_query(query, self.connection_info) } #[allow(dead_code)] @@ -99,7 +101,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> { .into_iter() .map(|(_, node)| { let edges = self.graph.incoming_edges(&node); - NodeTranslator::new(self.graph, node, &edges).translate() + NodeTranslator::new(self.graph, node, &edges, self.connection_info).translate() }) .collect::, _>>()?; @@ -121,7 +123,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> { .map(|(_, node)| { let name = node.id(); let edges = self.graph.incoming_edges(&node); - let expr = NodeTranslator::new(self.graph, node, &edges).translate()?; + let expr = NodeTranslator::new(self.graph, node, &edges, self.connection_info).translate()?; Ok(Binding { name, expr }) }) .collect::>>()?; diff --git a/query-engine/core/src/compiler/translate/query.rs b/query-engine/core/src/compiler/translate/query.rs index a54c0fe1cea..107bcb88774 100644 --- a/query-engine/core/src/compiler/translate/query.rs +++ b/query-engine/core/src/compiler/translate/query.rs @@ -3,7 +3,7 @@ mod read; mod write; use quaint::{ - prelude::{ConnectionInfo, ExternalConnectionInfo, SqlFamily}, + prelude::{ConnectionInfo, SqlFamily}, visitor::Visitor, }; use query_builder::DbQuery; @@ -15,14 +15,8 @@ use crate::{compiler::expression::Expression, Query}; use super::TranslateResult; -pub(crate) fn translate_query(query: Query) -> TranslateResult { - let connection_info = ConnectionInfo::External(ExternalConnectionInfo::new( - SqlFamily::Postgres, - "public".to_owned(), - None, - )); - - let ctx = Context::new(&connection_info, None); +pub(crate) fn translate_query(query: Query, connection_info: &ConnectionInfo) -> TranslateResult { + let ctx = Context::new(connection_info, None); match query { Query::Read(rq) => translate_read_query(rq, &ctx), @@ -30,8 +24,14 @@ pub(crate) fn translate_query(query: Query) -> TranslateResult { } } -fn build_db_query<'a>(query: impl Into>) -> TranslateResult { - let (sql, params) = quaint::visitor::Postgres::build(query)?; +fn build_db_query<'a>(query: impl Into>, ctx: &Context<'_>) -> TranslateResult { + let (sql, params) = match ctx.connection_info.sql_family() { + SqlFamily::Postgres => quaint::visitor::Postgres::build(query)?, + SqlFamily::Mysql => quaint::visitor::Mysql::build(query)?, + SqlFamily::Sqlite => quaint::visitor::Sqlite::build(query)?, + SqlFamily::Mssql => quaint::visitor::Mssql::build(query)?, + }; + let params = params .into_iter() .map(convert::quaint_value_to_prisma_value) diff --git a/query-engine/core/src/compiler/translate/query/read.rs b/query-engine/core/src/compiler/translate/query/read.rs index ab540d02472..0839416ce7c 100644 --- a/query-engine/core/src/compiler/translate/query/read.rs +++ b/query-engine/core/src/compiler/translate/query/read.rs @@ -31,7 +31,7 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans ) .limit(1); - let expr = Expression::Query(build_db_query(query)?); + let expr = Expression::Query(build_db_query(query, ctx)?); let expr = Expression::Unique(Box::new(expr)); if rq.nested.is_empty() { @@ -56,7 +56,7 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans ctx, ); - let expr = Expression::Query(build_db_query(query)?); + let expr = Expression::Query(build_db_query(query, ctx)?); let expr = if needs_reversed_order { Expression::Reverse(Box::new(expr)) @@ -186,7 +186,7 @@ fn build_read_one2m_query(rrq: RelatedRecordsQuery, ctx: &Context<'_>) -> Transl let query = if to_one_relation { query.limit(1) } else { query }; - let mut expr = Expression::Query(build_db_query(query)?); + let mut expr = Expression::Query(build_db_query(query, ctx)?); if to_one_relation { expr = Expression::Unique(Box::new(expr)); diff --git a/query-engine/core/src/compiler/translate/query/write.rs b/query-engine/core/src/compiler/translate/query/write.rs index 2dbbf12327c..4e361a6f746 100644 --- a/query-engine/core/src/compiler/translate/query/write.rs +++ b/query-engine/core/src/compiler/translate/query/write.rs @@ -19,7 +19,7 @@ pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> Tra // TODO: we probably need some additional node type or extra info in the WriteQuery node // to help the client executor figure out the returned ID in the case when it's inferred // from the query arguments. - Expression::Query(build_db_query(query)?) + Expression::Query(build_db_query(query, ctx)?) } WriteQuery::CreateManyRecords(cmr) => { @@ -33,7 +33,7 @@ pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> Tra ctx, ) .into_iter() - .map(build_db_query) + .map(|query| build_db_query(query, ctx)) .map(|maybe_db_query| maybe_db_query.map(Expression::Execute)) .collect::>>()?, ) @@ -41,7 +41,7 @@ pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> Tra Expression::Sum( generate_insert_statements(&cmr.model, cmr.args, cmr.skip_duplicates, None, ctx) .into_iter() - .map(build_db_query) + .map(|query| build_db_query(query, ctx)) .map(|maybe_db_query| maybe_db_query.map(Expression::Execute)) .collect::>>()?, ) diff --git a/query-engine/query-builders/sql-query-builder/src/context.rs b/query-engine/query-builders/sql-query-builder/src/context.rs index 6bb1f2a1414..9634df12b43 100644 --- a/query-engine/query-builders/sql-query-builder/src/context.rs +++ b/query-engine/query-builders/sql-query-builder/src/context.rs @@ -2,7 +2,7 @@ use quaint::prelude::ConnectionInfo; use telemetry::TraceParent; pub struct Context<'a> { - connection_info: &'a ConnectionInfo, + pub connection_info: &'a ConnectionInfo, pub(crate) traceparent: Option, /// Maximum rows allowed at once for an insert query. /// None is unlimited. diff --git a/query-engine/query-engine-node-api/src/engine.rs b/query-engine/query-engine-node-api/src/engine.rs index 1d17eb56ff8..ba78329c49d 100644 --- a/query-engine/query-engine-node-api/src/engine.rs +++ b/query-engine/query-engine-node-api/src/engine.rs @@ -4,7 +4,7 @@ use napi::{threadsafe_function::ThreadSafeCallContext, Env, JsFunction, JsObject use napi_derive::napi; use prisma_metrics::{MetricFormat, WithMetricsInstrumentation}; use psl::PreviewFeature; -use quaint::connector::ExternalConnector; +use quaint::connector::{ConnectionInfo, ExternalConnector}; use query_core::{protocol::EngineProtocol, relation_load_strategy, schema, TransactionOptions, TxId}; use query_engine_common::{ engine::{ @@ -368,7 +368,12 @@ impl QueryEngine { .into_doc(engine.query_schema()) .map_err(|err| napi::Error::from_reason(err.to_string()))?; - let plan = query_core::compiler::compile(engine.query_schema(), query_doc).map_err(ApiError::from)?; + let connection_info = match self.connector_mode { + ConnectorMode::Js { ref adapter } => ConnectionInfo::External(adapter.get_connection_info().await.map_err(|err| napi::Error::from_reason(err.to_string()))?), + ConnectorMode::Rust => return Err(napi::Error::from_reason("Query compiler requires JS driver adapter".to_string())), + }; + + let plan = query_core::compiler::compile(engine.query_schema(), query_doc, &connection_info).map_err(ApiError::from)?; let response = if human_readable { plan.to_string() diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index a1c7cc7846e..1f6c7ad82e6 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -8,7 +8,7 @@ use crate::{ use driver_adapters::JsObject; use js_sys::Function as JsFunction; use psl::ConnectorRegistry; -use quaint::connector::ExternalConnector; +use quaint::connector::{ConnectionInfo, ExternalConnector}; use query_core::{ protocol::EngineProtocol, relation_load_strategy, @@ -375,7 +375,9 @@ impl QueryEngine { let request = RequestBody::try_from_str(&request, engine.engine_protocol())?; let query_doc = request.into_doc(engine.query_schema())?; - let plan = query_core::compiler::compile(engine.query_schema(), query_doc).map_err(ApiError::from)?; + let connection_info = ConnectionInfo::External(self.adapter.get_connection_info().await?); + + let plan = query_core::compiler::compile(engine.query_schema(), query_doc, &connection_info).map_err(ApiError::from)?; Ok(serde_json::to_string(&plan)?) } .with_subscriber(dispatcher) diff --git a/query-engine/schema/src/query_schema.rs b/query-engine/schema/src/query_schema.rs index b14febbdc8c..bde266acb4e 100644 --- a/query-engine/schema/src/query_schema.rs +++ b/query-engine/schema/src/query_schema.rs @@ -25,7 +25,7 @@ pub struct QuerySchema { pub internal_data_model: InternalDataModel, pub(crate) enable_raw_queries: bool, - pub(crate) connector: &'static dyn Connector, + pub connector: &'static dyn Connector, /// Indexes query and mutation fields by their own query info for easier access. query_info_map: HashMap<(Operation, QueryInfo), usize>,