Skip to content

Commit

Permalink
feat(compiler): pass through connection_info
Browse files Browse the repository at this point in the history
  • Loading branch information
FGoessler committed Jan 15, 2025
1 parent 6bb3256 commit 63ed1c7
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 34 deletions.
5 changes: 3 additions & 2 deletions query-engine/core/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::sync::Arc;
pub use expression::Expression;
use schema::QuerySchema;
use thiserror::Error;
use quaint::connector::ConnectionInfo;
pub use translate::{translate, TranslateError};

use crate::{QueryDocument, QueryGraphBuilder};
Expand All @@ -19,11 +20,11 @@ pub enum CompileError {
TranslateError(#[from] TranslateError),

Check warning on line 20 in query-engine/core/src/compiler/mod.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/query-engine/core/src/compiler/mod.rs
}

pub fn compile(query_schema: &Arc<QuerySchema>, query_doc: QueryDocument) -> crate::Result<Expression> {
pub fn compile(query_schema: &Arc<QuerySchema>, query_doc: QueryDocument, connection_info: &ConnectionInfo) -> crate::Result<Expression> {
let QueryDocument::Single(query) = query_doc else {
return Err(CompileError::UnsupportedRequest.into());
};

let (graph, _serializer) = QueryGraphBuilder::new(query_schema).build(query)?;
Ok(translate(graph).map_err(CompileError::from)?)
Ok(translate(graph, connection_info).map_err(CompileError::from)?)
}
20 changes: 11 additions & 9 deletions query-engine/core/src/compiler/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod query;

use query::translate_query;
use thiserror::Error;

use quaint::connector::ConnectionInfo;
use crate::{EdgeRef, Node, NodeRef, Query, QueryGraph};

use super::expression::{Binding, Expression};
Expand All @@ -18,28 +18,30 @@ pub enum TranslateError {

pub type TranslateResult<T> = Result<T, TranslateError>;

pub fn translate(mut graph: QueryGraph) -> TranslateResult<Expression> {
pub fn translate(mut graph: QueryGraph, connection_info: &ConnectionInfo) -> TranslateResult<Expression> {
graph
.root_nodes()
.into_iter()
.map(|node| NodeTranslator::new(&mut graph, node, &[]).translate())
.map(|node| NodeTranslator::new(&mut graph, node, &[], connection_info).translate())
.collect::<TranslateResult<Vec<_>>>()
.map(Expression::Seq)
}

struct NodeTranslator<'a, 'b> {
struct NodeTranslator<'a, 'b, 'c> {
graph: &'a mut QueryGraph,
node: NodeRef,
#[allow(dead_code)]
parent_edges: &'b [EdgeRef],
connection_info: &'c ConnectionInfo,
}

Check warning on line 36 in query-engine/core/src/compiler/translate.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/query-engine/core/src/compiler/translate.rs

impl<'a, 'b> NodeTranslator<'a, 'b> {
fn new(graph: &'a mut QueryGraph, node: NodeRef, parent_edges: &'b [EdgeRef]) -> Self {
impl<'a, 'b, 'c> NodeTranslator<'a, 'b, 'c> {
fn new(graph: &'a mut QueryGraph, node: NodeRef, parent_edges: &'b [EdgeRef], connection_info: &'c ConnectionInfo) -> Self {
Self {
graph,
node,
parent_edges,
connection_info,
}
}

Expand All @@ -64,7 +66,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
.try_into()
.expect("current node must be query");

translate_query(query)
translate_query(query, self.connection_info)
}

#[allow(dead_code)]
Expand Down Expand Up @@ -99,7 +101,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
.into_iter()
.map(|(_, node)| {
let edges = self.graph.incoming_edges(&node);
NodeTranslator::new(self.graph, node, &edges).translate()
NodeTranslator::new(self.graph, node, &edges, self.connection_info).translate()
})
.collect::<Result<Vec<_>, _>>()?;

Expand All @@ -121,7 +123,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
.map(|(_, node)| {
let name = node.id();
let edges = self.graph.incoming_edges(&node);
let expr = NodeTranslator::new(self.graph, node, &edges).translate()?;
let expr = NodeTranslator::new(self.graph, node, &edges, self.connection_info).translate()?;
Ok(Binding { name, expr })
})
.collect::<TranslateResult<Vec<_>>>()?;
Expand Down
22 changes: 11 additions & 11 deletions query-engine/core/src/compiler/translate/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod read;
mod write;

use quaint::{
prelude::{ConnectionInfo, ExternalConnectionInfo, SqlFamily},
prelude::{ConnectionInfo, SqlFamily},
visitor::Visitor,
};
use query_builder::DbQuery;
Expand All @@ -15,23 +15,23 @@ use crate::{compiler::expression::Expression, Query};

use super::TranslateResult;

pub(crate) fn translate_query(query: Query) -> TranslateResult<Expression> {
let connection_info = ConnectionInfo::External(ExternalConnectionInfo::new(
SqlFamily::Postgres,
"public".to_owned(),
None,
));

let ctx = Context::new(&connection_info, None);
pub(crate) fn translate_query(query: Query, connection_info: &ConnectionInfo) -> TranslateResult<Expression> {
let ctx = Context::new(connection_info, None);

match query {
Query::Read(rq) => translate_read_query(rq, &ctx),
Query::Write(wq) => translate_write_query(wq, &ctx),
}
}

fn build_db_query<'a>(query: impl Into<quaint::ast::Query<'a>>) -> TranslateResult<DbQuery> {
let (sql, params) = quaint::visitor::Postgres::build(query)?;
fn build_db_query<'a>(query: impl Into<quaint::ast::Query<'a>>, ctx: &Context<'_>) -> TranslateResult<DbQuery> {
let (sql, params) = match ctx.connection_info.sql_family() {
SqlFamily::Postgres => quaint::visitor::Postgres::build(query)?,
SqlFamily::Mysql => quaint::visitor::Mysql::build(query)?,
SqlFamily::Sqlite => quaint::visitor::Sqlite::build(query)?,
SqlFamily::Mssql => quaint::visitor::Mssql::build(query)?,
};

let params = params
.into_iter()
.map(convert::quaint_value_to_prisma_value)
Expand Down
6 changes: 3 additions & 3 deletions query-engine/core/src/compiler/translate/query/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans
)
.limit(1);

let expr = Expression::Query(build_db_query(query)?);
let expr = Expression::Query(build_db_query(query, ctx)?);
let expr = Expression::Unique(Box::new(expr));

if rq.nested.is_empty() {
Expand All @@ -56,7 +56,7 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans
ctx,
);

let expr = Expression::Query(build_db_query(query)?);
let expr = Expression::Query(build_db_query(query, ctx)?);

let expr = if needs_reversed_order {
Expression::Reverse(Box::new(expr))
Expand Down Expand Up @@ -186,7 +186,7 @@ fn build_read_one2m_query(rrq: RelatedRecordsQuery, ctx: &Context<'_>) -> Transl

let query = if to_one_relation { query.limit(1) } else { query };

let mut expr = Expression::Query(build_db_query(query)?);
let mut expr = Expression::Query(build_db_query(query, ctx)?);

if to_one_relation {
expr = Expression::Unique(Box::new(expr));
Expand Down
6 changes: 3 additions & 3 deletions query-engine/core/src/compiler/translate/query/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> Tra
// TODO: we probably need some additional node type or extra info in the WriteQuery node
// to help the client executor figure out the returned ID in the case when it's inferred
// from the query arguments.
Expression::Query(build_db_query(query)?)
Expression::Query(build_db_query(query, ctx)?)
}

WriteQuery::CreateManyRecords(cmr) => {
Expand All @@ -33,15 +33,15 @@ pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> Tra
ctx,
)
.into_iter()
.map(build_db_query)
.map(|query| build_db_query(query, ctx))
.map(|maybe_db_query| maybe_db_query.map(Expression::Execute))
.collect::<TranslateResult<Vec<_>>>()?,
)
} else {
Expression::Sum(
generate_insert_statements(&cmr.model, cmr.args, cmr.skip_duplicates, None, ctx)
.into_iter()
.map(build_db_query)
.map(|query| build_db_query(query, ctx))
.map(|maybe_db_query| maybe_db_query.map(Expression::Execute))
.collect::<TranslateResult<Vec<_>>>()?,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use quaint::prelude::ConnectionInfo;
use telemetry::TraceParent;

pub struct Context<'a> {
connection_info: &'a ConnectionInfo,
pub connection_info: &'a ConnectionInfo,
pub(crate) traceparent: Option<TraceParent>,
/// Maximum rows allowed at once for an insert query.
/// None is unlimited.
Expand Down
9 changes: 7 additions & 2 deletions query-engine/query-engine-node-api/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use napi::{threadsafe_function::ThreadSafeCallContext, Env, JsFunction, JsObject
use napi_derive::napi;
use prisma_metrics::{MetricFormat, WithMetricsInstrumentation};
use psl::PreviewFeature;
use quaint::connector::ExternalConnector;
use quaint::connector::{ConnectionInfo, ExternalConnector};
use query_core::{protocol::EngineProtocol, relation_load_strategy, schema, TransactionOptions, TxId};
use query_engine_common::{
engine::{
Expand Down Expand Up @@ -368,7 +368,12 @@ impl QueryEngine {
.into_doc(engine.query_schema())
.map_err(|err| napi::Error::from_reason(err.to_string()))?;

Check warning on line 369 in query-engine/query-engine-node-api/src/engine.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/query-engine/query-engine-node-api/src/engine.rs

let plan = query_core::compiler::compile(engine.query_schema(), query_doc).map_err(ApiError::from)?;
let connection_info = match self.connector_mode {
ConnectorMode::Js { ref adapter } => ConnectionInfo::External(adapter.get_connection_info().await.map_err(|err| napi::Error::from_reason(err.to_string()))?),
ConnectorMode::Rust => return Err(napi::Error::from_reason("Query compiler requires JS driver adapter".to_string())),
};

let plan = query_core::compiler::compile(engine.query_schema(), query_doc, &connection_info).map_err(ApiError::from)?;

let response = if human_readable {
plan.to_string()
Expand Down
6 changes: 4 additions & 2 deletions query-engine/query-engine-wasm/src/wasm/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
use driver_adapters::JsObject;
use js_sys::Function as JsFunction;
use psl::ConnectorRegistry;
use quaint::connector::ExternalConnector;
use quaint::connector::{ConnectionInfo, ExternalConnector};
use query_core::{
protocol::EngineProtocol,
relation_load_strategy,
Expand Down Expand Up @@ -375,7 +375,9 @@ impl QueryEngine {
let request = RequestBody::try_from_str(&request, engine.engine_protocol())?;
let query_doc = request.into_doc(engine.query_schema())?;

Check warning on line 377 in query-engine/query-engine-wasm/src/wasm/engine.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/query-engine/query-engine-wasm/src/wasm/engine.rs
let plan = query_core::compiler::compile(engine.query_schema(), query_doc).map_err(ApiError::from)?;
let connection_info = ConnectionInfo::External(self.adapter.get_connection_info().await?);

let plan = query_core::compiler::compile(engine.query_schema(), query_doc, &connection_info).map_err(ApiError::from)?;
Ok(serde_json::to_string(&plan)?)
}
.with_subscriber(dispatcher)
Expand Down
2 changes: 1 addition & 1 deletion query-engine/schema/src/query_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub struct QuerySchema {
pub internal_data_model: InternalDataModel,

pub(crate) enable_raw_queries: bool,
pub(crate) connector: &'static dyn Connector,
pub connector: &'static dyn Connector,

/// Indexes query and mutation fields by their own query info for easier access.
query_info_map: HashMap<(Operation, QueryInfo), usize>,
Expand Down

0 comments on commit 63ed1c7

Please sign in to comment.