diff --git a/.envrc b/.envrc
index 5436426aa834..ed5ae2b6d10b 100644
--- a/.envrc
+++ b/.envrc
@@ -53,3 +53,11 @@ if command -v nix &> /dev/null && [ -z ${DISABLE_NIX+x} ]
 then
   use nix
 fi
+
+if [[ "$OSTYPE" == "linux-gnu"* ]] && command -v lld &> /dev/null && [ ! -f .cargo/config.toml ]; then
+  mkdir -p .cargo
+  cat << EOF > .cargo/config.toml
+[target.$(uname -m)-unknown-linux-gnu]
+rustflags = ["-C", "link-arg=-fuse-ld=lld"]
+EOF
+fi
diff --git a/.github/workflows/build-engines.yml b/.github/workflows/build-engines.yml
index db1c1a42ce6d..ac7a72024b52 100644
--- a/.github/workflows/build-engines.yml
+++ b/.github/workflows/build-engines.yml
@@ -6,30 +6,30 @@ on:
   push:
     branches:
       - main
-      - '*.*.x'
-      - 'integration/*'
+      - "*.*.x"
+      - "integration/*"
     paths-ignore:
-      - '!.github/workflows/build-engines*'
-      - '.github/**'
-      - '.buildkite/**'
-      - '*.md'
-      - 'LICENSE'
-      - 'CODEOWNERS'
-      - 'renovate.json'
+      - "!.github/workflows/build-engines*"
+      - ".github/**"
+      - ".buildkite/**"
+      - "*.md"
+      - "LICENSE"
+      - "CODEOWNERS"
+      - "renovate.json"
   workflow_dispatch:
   pull_request:
     paths-ignore:
-      - '!.github/workflows/build-engines*'
-      - '.github/**'
-      - '.buildkite/**'
-      - '*.md'
-      - 'LICENSE'
-      - 'CODEOWNERS'
-      - 'renovate.json'
+      - "!.github/workflows/build-engines*"
+      - ".github/**"
+      - ".buildkite/**"
+      - "*.md"
+      - "LICENSE"
+      - "CODEOWNERS"
+      - "renovate.json"

 jobs:
   is-release-necessary:
-    name: 'Decide if a release of the engines artifacts is necessary'
+    name: "Decide if a release of the engines artifacts is necessary"
     runs-on: ubuntu-22.04
     outputs:
       release: ${{ steps.decision.outputs.release }}
@@ -59,7 +59,7 @@ jobs:
           END_OF_COMMIT_MESSAGE
           echo "Commit message contains [integration]: ${{ contains(steps.commit-msg.outputs.commit-msg, '[integration]') }}"

-      - name: 'Check if commit message conatains `[integration]` and the PR author has permissions to trigger the workflow'
+      - name: "Check if commit message contains `[integration]` and the PR author has permissions to trigger the workflow"
        id: check-commit-message
        # See https://docs.github.com/en/graphql/reference/enums
        # https://michaelheap.com/github-actions-check-permission/
@@ -68,7 +68,7 @@
        # - the PR author has permissions to trigger the workflow (must be part of the org or a collaborator)
        if: |
          github.event_name == 'pull_request' &&
-          contains(steps.commit-msg.outputs.commit-msg, '[integration]') &&
+          contains(steps.commit-msg.outputs.commit-msg, '[integration]') &&
          (
            github.event.pull_request.author_association == 'OWNER' ||
            github.event.pull_request.author_association == 'MEMBER' ||
            github.event.pull_request.author_association == 'COLLABORATOR'
          )
        run: |
-          echo "Commit message contains [integration] and PR author has permissions"
-          # set value to GitHub output
+          echo "Commit message contains [integration] and PR author has permissions"
+          # set value to GitHub output
          echo "release=true" >> $GITHUB_OUTPUT

      #
@@ -118,9 +118,9 @@ jobs:
      # https://github.com/peter-evans/find-comment/tree/v3/?tab=readme-ov-file#outputs
      # Tip: Empty strings evaluate to zero in GitHub Actions expressions. e.g. If comment-id is an empty string steps.fc.outputs.comment-id == 0 evaluates to true.
if: | - github.event_name == 'workflow_dispatch' || - github.event_name == 'push' || - steps.check-commit-message.outputs.release == 'true' || + github.event_name == 'workflow_dispatch' || + github.event_name == 'push' || + steps.check-commit-message.outputs.release == 'true' || steps.check-branch.outputs.release == 'true' id: decision @@ -140,7 +140,7 @@ jobs: build-linux: name: Build Engines for Linux - needs: + needs: - is-release-necessary if: ${{ needs.is-release-necessary.outputs.release == 'true' }} uses: ./.github/workflows/build-engines-linux-template.yml @@ -149,7 +149,7 @@ jobs: build-macos-intel: name: Build Engines for Apple Intel - needs: + needs: - is-release-necessary if: ${{ needs.is-release-necessary.outputs.release == 'true' }} uses: ./.github/workflows/build-engines-apple-intel-template.yml @@ -158,25 +158,25 @@ jobs: build-macos-silicon: name: Build Engines for Apple Silicon - needs: + needs: - is-release-necessary if: ${{ needs.is-release-necessary.outputs.release == 'true' }} uses: ./.github/workflows/build-engines-apple-silicon-template.yml with: commit: ${{ github.sha }} - build-react-native: - name: Build Engines for React native - needs: - - is-release-necessary - if: ${{ needs.is-release-necessary.outputs.release == 'true' }} - uses: ./.github/workflows/build-engines-react-native-template.yml - with: - commit: ${{ github.sha }} + # build-react-native: + # name: Build Engines for React native + # needs: + # - is-release-necessary + # if: ${{ needs.is-release-necessary.outputs.release == 'true' }} + # uses: ./.github/workflows/build-engines-react-native-template.yml + # with: + # commit: ${{ github.sha }} build-windows: name: Build Engines for Windows - needs: + needs: - is-release-necessary if: ${{ needs.is-release-necessary.outputs.release == 'true' }} uses: ./.github/workflows/build-engines-windows-template.yml @@ -184,7 +184,7 @@ jobs: commit: ${{ github.sha }} release-artifacts: - name: 'Release artifacts from branch ${{ github.head_ref || github.ref_name }} for commit ${{ github.sha }}' + name: "Release artifacts from branch ${{ github.head_ref || github.ref_name }} for commit ${{ github.sha }}" runs-on: ubuntu-22.04 concurrency: group: ${{ github.sha }} @@ -192,12 +192,12 @@ jobs: - build-linux - build-macos-intel - build-macos-silicon - - build-react-native + # - build-react-native - build-windows env: - BUCKET_NAME: 'prisma-builds' + BUCKET_NAME: "prisma-builds" PRISMA_ENGINES_COMMIT_SHA: ${{ github.sha }} - DESTINATION_TARGET_PATH: 's3://prisma-builds/all_commits/${{ github.sha }}' + DESTINATION_TARGET_PATH: "s3://prisma-builds/all_commits/${{ github.sha }}" steps: # Because we need the scripts @@ -215,22 +215,22 @@ jobs: # run-id: 9526334324 # github-token: ${{ secrets.GITHUB_TOKEN }} - - name: 'R2: Check if artifacts were already built and uploaded before via `.finished` file' + - name: "R2: Check if artifacts were already built and uploaded before via `.finished` file" env: - FILE_PATH: 'all_commits/${{ github.sha }}/.finished' - FILE_PATH_LEGACY: 'all_commits/${{ github.sha }}/rhel-openssl-1.1.x/.finished' - AWS_DEFAULT_REGION: 'auto' + FILE_PATH: "all_commits/${{ github.sha }}/.finished" + FILE_PATH_LEGACY: "all_commits/${{ github.sha }}/rhel-openssl-1.1.x/.finished" + AWS_DEFAULT_REGION: "auto" AWS_ACCESS_KEY_ID: ${{ vars.R2_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} AWS_ENDPOINT_URL_S3: ${{ vars.R2_ENDPOINT }} working-directory: .github/workflows/utils run: bash checkFinishedMarker.sh - - name: 'S3: Check if 
artifacts were already built and uploaded before via `.finished` file' + - name: "S3: Check if artifacts were already built and uploaded before via `.finished` file" env: - FILE_PATH: 'all_commits/${{ github.sha }}/.finished' - FILE_PATH_LEGACY: 'all_commits/${{ github.sha }}/rhel-openssl-1.1.x/.finished' - AWS_DEFAULT_REGION: 'eu-west-1' + FILE_PATH: "all_commits/${{ github.sha }}/.finished" + FILE_PATH_LEGACY: "all_commits/${{ github.sha }}/rhel-openssl-1.1.x/.finished" + AWS_DEFAULT_REGION: "eu-west-1" AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} working-directory: .github/workflows/utils @@ -247,14 +247,14 @@ jobs: cp -r rhel-openssl-1.1.x debian-openssl-1.1.x cp -r rhel-openssl-3.0.x debian-openssl-3.0.x - - name: Create .zip for react-native - working-directory: engines-artifacts - run: | - mkdir react-native - zip -r react-native/binaries.zip ios android - rm -rf ios android + # - name: Create .zip for react-native + # working-directory: engines-artifacts + # run: | + # mkdir react-native + # zip -r react-native/binaries.zip ios android + # rm -rf ios android - - name: 'Create compressed engine files (.gz)' + - name: "Create compressed engine files (.gz)" working-directory: engines-artifacts run: | set -eu @@ -266,13 +266,13 @@ jobs: ls -Rl . - - name: 'Create SHA256 checksum files (.sha256).' + - name: "Create SHA256 checksum files (.sha256)." working-directory: engines-artifacts run: | set -eu find . -type f | while read filename; do - sha256sum "$filename" > "$filename.sha256" + sha256sum "$filename" > "$filename.sha256" echo "$filename.sha256 file created." done @@ -292,7 +292,7 @@ jobs: run: gpg -K # next to each file (excluding .sha256 files) - - name: 'Create a GPG detached signature (.sig)' + - name: "Create a GPG detached signature (.sig)" working-directory: engines-artifacts run: | set -eu @@ -303,18 +303,18 @@ jobs: ls -Rl . 
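+      # A hedged summary, not part of the upstream patch: every engine artifact now has .gz, .sha256 and .sig siblings; the two steps below mirror the tree to Cloudflare R2 and AWS S3, each finishing with a `.finished` marker.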
-      - name: 'Cloudflare R2: Upload to bucket and verify uploaded files then create `.finished` file'
+      - name: "Cloudflare R2: Upload to bucket and verify uploaded files then create `.finished` file"
        # https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html
        env:
-          AWS_DEFAULT_REGION: 'auto'
+          AWS_DEFAULT_REGION: "auto"
          AWS_ACCESS_KEY_ID: ${{ vars.R2_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
          AWS_ENDPOINT_URL_S3: ${{ vars.R2_ENDPOINT }}
        run: bash .github/workflows/utils/uploadAndVerify.sh engines-artifacts-for-r2

-      - name: 'AWS S3: Upload to bucket and verify uploaded files then create `.finished` file'
+      - name: "AWS S3: Upload to bucket and verify uploaded files then create `.finished` file"
        env:
-          AWS_DEFAULT_REGION: 'eu-west-1'
+          AWS_DEFAULT_REGION: "eu-west-1"
          AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        run: bash .github/workflows/utils/uploadAndVerify.sh engines-artifacts-for-s3
diff --git a/.github/workflows/test-compilation.yml b/.github/workflows/test-compilation.yml
index 3db71c67b5e7..193fdc26da96 100644
--- a/.github/workflows/test-compilation.yml
+++ b/.github/workflows/test-compilation.yml
@@ -50,9 +50,9 @@ jobs:
        - name: "Check that Cargo.lock did not change"
          run: "git diff --exit-code"

-  test-react-native-compilation:
-    name: React Native
-    uses: ./.github/workflows/build-engines-react-native-template.yml
-    with:
-      commit: ${{ github.sha }}
-      uploadArtifacts: false
+  # test-react-native-compilation:
+  #   name: React Native
+  #   uses: ./.github/workflows/build-engines-react-native-template.yml
+  #   with:
+  #     commit: ${{ github.sha }}
+  #     uploadArtifacts: false
diff --git a/Cargo.lock b/Cargo.lock
index fc18c125a92c..6783c338e32b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3681,11 +3681,13 @@ dependencies = [
 "petgraph",
 "prisma-metrics",
 "psl",
+ "quaint",
 "query-connector",
 "query-structure",
 "schema",
 "serde",
 "serde_json",
+ "sql-query-connector",
 "telemetry",
 "thiserror",
 "tokio",
@@ -3709,6 +3711,7 @@ dependencies = [
 "enumflags2",
 "graphql-parser",
 "hyper",
+ "indexmap 2.2.2",
 "indoc 2.0.3",
 "mongodb-query-connector",
 "prisma-metrics",
diff --git a/libs/prisma-value/src/lib.rs b/libs/prisma-value/src/lib.rs
index 8a1b10c2aedb..01a4e5e50572 100644
--- a/libs/prisma-value/src/lib.rs
+++ b/libs/prisma-value/src/lib.rs
@@ -8,6 +8,7 @@ use chrono::prelude::*;
 use serde::de::Unexpected;
 use serde::ser::SerializeMap;
 use serde::{ser::Serializer, Deserialize, Deserializer, Serialize};
+use serde_json::json;
 use std::{convert::TryFrom, fmt, str::FromStr};
 use uuid::Uuid;

@@ -47,6 +48,45 @@ pub enum PrismaValue {
     #[serde(serialize_with = "serialize_bytes")]
     Bytes(Vec<u8>),
+
+    #[serde(serialize_with = "serialize_placeholder")]
+    Placeholder {
+        name: String,
+        r#type: PlaceholderType,
+    },
+}
+
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
+pub enum PlaceholderType {
+    Any,
+    String,
+    Int,
+    BigInt,
+    Float,
+    Boolean,
+    Decimal,
+    Date,
+    Array(Box<PlaceholderType>),
+    Object,
+    Bytes,
+}
+
+impl std::fmt::Display for PlaceholderType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PlaceholderType::Any => write!(f, "Any"),
+            PlaceholderType::String => write!(f, "String"),
+            PlaceholderType::Int => write!(f, "Int"),
+            PlaceholderType::BigInt => write!(f, "BigInt"),
+            PlaceholderType::Float => write!(f, "Float"),
+            PlaceholderType::Boolean => write!(f, "Boolean"),
+            PlaceholderType::Decimal => write!(f, "Decimal"),
+            PlaceholderType::Date => write!(f, "Date"),
+            PlaceholderType::Array(t) => write!(f, "Array<{t}>"),
+            PlaceholderType::Object => write!(f, "Object"),
+            PlaceholderType::Bytes => write!(f, "Bytes"),
+        }
+    }
 }

 /// Stringify a date to the following format
@@ -107,6 +147,7 @@ impl TryFrom<serde_json::Value> for PrismaValue {
                    Ok(PrismaValue::DateTime(date))
                }
+
                Some("bigint") => {
                    let value = obj
                        .get("prisma__value")
@@ -117,6 +158,7 @@ impl TryFrom<serde_json::Value> for PrismaValue {
                        .map(PrismaValue::BigInt)
                        .map_err(|_| ConversionFailure::new("JSON bigint value", "PrismaValue"))
                }
+
                Some("decimal") => {
                    let value = obj
                        .get("prisma__value")
@@ -127,6 +169,7 @@ impl TryFrom<serde_json::Value> for PrismaValue {
                        .map(PrismaValue::Float)
                        .map_err(|_| ConversionFailure::new("JSON decimal value", "PrismaValue"))
                }
+
                Some("bytes") => {
                    let value = obj
                        .get("prisma__value")
@@ -136,6 +179,24 @@ impl TryFrom<serde_json::Value> for PrismaValue {
                    decode_bytes(value).map(PrismaValue::Bytes)
                }

+                Some("param") => {
+                    let obj = obj
+                        .get("prisma__value")
+                        .and_then(|v| v.as_object())
+                        .ok_or_else(|| ConversionFailure::new("JSON param value", "PrismaValue"))?;
+
+                    let name = obj
+                        .get("name")
+                        .and_then(|v| v.as_str())
+                        .ok_or_else(|| ConversionFailure::new("param name", "JSON param value"))?
+                        .to_owned();
+
+                    Ok(PrismaValue::Placeholder {
+                        name,
+                        r#type: PlaceholderType::Any, // parsing the type is not implemented yet
+                    })
+                }
+
                _ => Ok(PrismaValue::Json(serde_json::to_string(&obj).unwrap())),
            },
        }
@@ -197,6 +258,24 @@
    map.end()
 }

+fn serialize_placeholder<S>(name: &str, r#type: &PlaceholderType, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    let mut map = serializer.serialize_map(Some(2))?;
+
+    map.serialize_entry("prisma__type", "param")?;
+    map.serialize_entry(
+        "prisma__value",
+        &json!({
+            "name": name,
+            "type": r#type.to_string(),
+        }),
+    )?;
+
+    map.end()
+}
+
 struct BigDecimalVisitor;

 impl serde::de::Visitor<'_> for BigDecimalVisitor {
@@ -345,6 +424,7 @@ impl fmt::Display for PrismaValue {
                write!(f, "{{ {joined} }}")
            }
+            PrismaValue::Placeholder { name, r#type } => write!(f, "var({name}: {type})"),
        }
    }
 }
diff --git a/quaint/src/ast.rs b/quaint/src/ast.rs
index 66d37a5754a7..50aa38cc4d1f 100644
--- a/quaint/src/ast.rs
+++ b/quaint/src/ast.rs
@@ -53,5 +53,5 @@ pub use select::{DistinctType, Select};
 pub use table::*;
 pub use union::Union;
 pub use update::*;
-pub use values::{IntoRaw, Raw, Value, ValueType, Values};
+pub use values::{IntoRaw, Raw, Value, ValueType, Values, VarType};
 pub(crate) use values::{NativeColumnType, Params};
diff --git a/quaint/src/ast/values.rs b/quaint/src/ast/values.rs
index 008191150618..256479fce6c1 100644
--- a/quaint/src/ast/values.rs
+++ b/quaint/src/ast/values.rs
@@ -225,6 +225,11 @@ impl<'a> Value<'a> {
        ValueType::xml(value).into_value()
    }

+    /// Creates a new variable.
+    pub fn var(name: impl Into<Cow<'a, str>>, ty: VarType) -> Self {
+        ValueType::var(name, ty).into_value()
+    }
+
    /// `true` if the `Value` is null.
    pub fn is_null(&self) -> bool {
        self.typed.is_null()
@@ -553,6 +558,59 @@ pub enum ValueType<'a> {
    Date(Option<NaiveDate>),
    /// A time value.
    Time(Option<NaiveTime>),
+    /// A variable that doesn't have a value assigned yet.
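+    /// For example, `Value::var("userId", VarType::Int32)` stands for a not-yet-bound `Int32`
+    /// parameter named `userId`. Visitors refuse to render such a value inline (`VarAsRawValue`),
+    /// and native drivers refuse to execute queries that still contain one (`RanQueryWithVarParam`).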
+    Var(Cow<'a, str>, VarType),
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum VarType {
+    Unknown,
+    Int32,
+    Int64,
+    Float,
+    Double,
+    Text,
+    Enum,
+    Bytes,
+    Boolean,
+    Char,
+    Array(Box<VarType>),
+    Numeric,
+    Json,
+    Xml,
+    Uuid,
+    DateTime,
+    Date,
+    Time,
+}
+
+impl fmt::Display for VarType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            VarType::Unknown => write!(f, "Unknown"),
+            VarType::Int32 => write!(f, "Int32"),
+            VarType::Int64 => write!(f, "Int64"),
+            VarType::Float => write!(f, "Float"),
+            VarType::Double => write!(f, "Double"),
+            VarType::Text => write!(f, "Text"),
+            VarType::Enum => write!(f, "Enum"),
+            VarType::Bytes => write!(f, "Bytes"),
+            VarType::Boolean => write!(f, "Boolean"),
+            VarType::Char => write!(f, "Char"),
+            VarType::Array(t) => {
+                write!(f, "Array<")?;
+                t.fmt(f)?;
+                write!(f, ">")
+            }
+            VarType::Numeric => write!(f, "Numeric"),
+            VarType::Json => write!(f, "Json"),
+            VarType::Xml => write!(f, "Xml"),
+            VarType::Uuid => write!(f, "Uuid"),
+            VarType::DateTime => write!(f, "DateTime"),
+            VarType::Date => write!(f, "Date"),
+            VarType::Time => write!(f, "Time"),
+        }
+    }
 }

 pub(crate) struct Params<'a>(pub(crate) &'a [Value<'a>]);
@@ -619,6 +677,7 @@ impl fmt::Display for ValueType<'_> {
            ValueType::DateTime(val) => val.map(|v| write!(f, "\"{v}\"")),
            ValueType::Date(val) => val.map(|v| write!(f, "\"{v}\"")),
            ValueType::Time(val) => val.map(|v| write!(f, "\"{v}\"")),
+            ValueType::Var(name, ty) => Some(write!(f, "${name} as {ty}")),
        };

        match res {
@@ -677,6 +736,7 @@ impl<'a> From<Value<'a>> for serde_json::Value {
            ValueType::DateTime(dt) => dt.map(|dt| serde_json::Value::String(dt.to_rfc3339())),
            ValueType::Date(date) => date.map(|date| serde_json::Value::String(format!("{date}"))),
            ValueType::Time(time) => time.map(|time| serde_json::Value::String(format!("{time}"))),
+            ValueType::Var(_, _) => todo!(),
        };

        match res {
@@ -830,6 +890,11 @@ impl<'a> ValueType<'a> {
        Self::Xml(Some(value.into()))
    }

+    /// Creates a new variable.
+    pub fn var(name: impl Into<Cow<'a, str>>, ty: VarType) -> Self {
+        Self::Var(name.into(), ty)
+    }
+
    /// `true` if the `Value` is null.
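    /// A `Var` is never considered null: it stands for a value that is not known yet,
    /// not for a value that is known to be absent.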
pub fn is_null(&self) -> bool { match self { @@ -851,6 +916,7 @@ impl<'a> ValueType<'a> { Self::Date(d) => d.is_none(), Self::Time(t) => t.is_none(), Self::Json(json) => json.is_none(), + Self::Var(_, _) => false, } } diff --git a/quaint/src/connector/column_type.rs b/quaint/src/connector/column_type.rs index 38fb3d786dc0..d8cd9d46d19f 100644 --- a/quaint/src/connector/column_type.rs +++ b/quaint/src/connector/column_type.rs @@ -1,7 +1,7 @@ #[cfg(not(target_arch = "wasm32"))] use super::TypeIdentifier; -use crate::{Value, ValueType}; +use crate::{ast::VarType, Value, ValueType}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ColumnType { @@ -99,23 +99,24 @@ impl From<&Value<'_>> for ColumnType { impl From<&ValueType<'_>> for ColumnType { fn from(value: &ValueType) -> Self { match value { - ValueType::Int32(_) => ColumnType::Int32, - ValueType::Int64(_) => ColumnType::Int64, - ValueType::Float(_) => ColumnType::Float, - ValueType::Double(_) => ColumnType::Double, - ValueType::Text(_) => ColumnType::Text, - ValueType::Enum(_, _) => ColumnType::Enum, + ValueType::Int32(_) | ValueType::Var(_, VarType::Int32) => ColumnType::Int32, + ValueType::Int64(_) | ValueType::Var(_, VarType::Int64) => ColumnType::Int64, + ValueType::Float(_) | ValueType::Var(_, VarType::Float) => ColumnType::Float, + ValueType::Double(_) | ValueType::Var(_, VarType::Double) => ColumnType::Double, + ValueType::Text(_) | ValueType::Var(_, VarType::Text) => ColumnType::Text, + ValueType::Enum(_, _) | ValueType::Var(_, VarType::Enum) => ColumnType::Enum, ValueType::EnumArray(_, _) => ColumnType::TextArray, - ValueType::Bytes(_) => ColumnType::Bytes, - ValueType::Boolean(_) => ColumnType::Boolean, - ValueType::Char(_) => ColumnType::Char, - ValueType::Numeric(_) => ColumnType::Numeric, - ValueType::Json(_) => ColumnType::Json, - ValueType::Xml(_) => ColumnType::Xml, - ValueType::Uuid(_) => ColumnType::Uuid, - ValueType::DateTime(_) => ColumnType::DateTime, - ValueType::Date(_) => ColumnType::Date, - ValueType::Time(_) => ColumnType::Time, + ValueType::Var(_, VarType::Array(vt)) if **vt == VarType::Enum => ColumnType::TextArray, + ValueType::Bytes(_) | ValueType::Var(_, VarType::Bytes) => ColumnType::Bytes, + ValueType::Boolean(_) | ValueType::Var(_, VarType::Boolean) => ColumnType::Boolean, + ValueType::Char(_) | ValueType::Var(_, VarType::Char) => ColumnType::Char, + ValueType::Numeric(_) | ValueType::Var(_, VarType::Numeric) => ColumnType::Numeric, + ValueType::Json(_) | ValueType::Var(_, VarType::Json) => ColumnType::Json, + ValueType::Xml(_) | ValueType::Var(_, VarType::Xml) => ColumnType::Xml, + ValueType::Uuid(_) | ValueType::Var(_, VarType::Uuid) => ColumnType::Uuid, + ValueType::DateTime(_) | ValueType::Var(_, VarType::DateTime) => ColumnType::DateTime, + ValueType::Date(_) | ValueType::Var(_, VarType::Date) => ColumnType::Date, + ValueType::Time(_) | ValueType::Var(_, VarType::Time) => ColumnType::Time, ValueType::Array(Some(vals)) if !vals.is_empty() => match &vals[0].typed { ValueType::Int32(_) => ColumnType::Int32Array, ValueType::Int64(_) => ColumnType::Int64Array, @@ -135,8 +136,30 @@ impl From<&ValueType<'_>> for ColumnType { ValueType::Time(_) => ColumnType::TimeArray, ValueType::Array(_) => ColumnType::Unknown, ValueType::EnumArray(_, _) => ColumnType::Unknown, + ValueType::Var(_, _) => ColumnType::Unknown, }, ValueType::Array(_) => ColumnType::Unknown, + ValueType::Var(_, VarType::Unknown) => ColumnType::Unknown, + ValueType::Var(_, VarType::Array(vt)) => match **vt { + VarType::Int32 => 
ColumnType::Int32Array,
+                VarType::Int64 => ColumnType::Int64Array,
+                VarType::Float => ColumnType::FloatArray,
+                VarType::Double => ColumnType::DoubleArray,
+                VarType::Text => ColumnType::TextArray,
+                VarType::Enum => ColumnType::TextArray,
+                VarType::Bytes => ColumnType::BytesArray,
+                VarType::Boolean => ColumnType::BooleanArray,
+                VarType::Char => ColumnType::CharArray,
+                VarType::Numeric => ColumnType::NumericArray,
+                VarType::Json => ColumnType::JsonArray,
+                VarType::Xml => ColumnType::TextArray,
+                VarType::Uuid => ColumnType::UuidArray,
+                VarType::DateTime => ColumnType::DateTimeArray,
+                VarType::Date => ColumnType::DateArray,
+                VarType::Time => ColumnType::TimeArray,
+                VarType::Unknown => ColumnType::Unknown,
+                VarType::Array(_) => ColumnType::Unknown,
+            },
        }
    }
 }
diff --git a/quaint/src/connector/mssql/native/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs
index 5d2eb2eb08b8..1ab099946252 100644
--- a/quaint/src/connector/mssql/native/conversion.rs
+++ b/quaint/src/connector/mssql/native/conversion.rs
@@ -1,4 +1,7 @@
-use crate::ast::{Value, ValueType};
+use crate::{
+    ast::{Value, ValueType},
+    error::{Error, ErrorKind},
+};

 use bigdecimal::BigDecimal;
 use std::{borrow::Cow, convert::TryFrom};
@@ -25,6 +28,12 @@ impl<'a> IntoSql<'a> for &'a Value<'a> {
            ValueType::DateTime(val) => val.into_sql(),
            ValueType::Date(val) => val.into_sql(),
            ValueType::Time(val) => val.into_sql(),
+            ValueType::Var(name, _) => {
+                panic!(
+                    "conversion error: {:?}",
+                    Error::builder(ErrorKind::RanQueryWithVarParam(name.clone().into_owned())).build()
+                )
+            }
        }
    }
 }
diff --git a/quaint/src/connector/mysql/native/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs
index 1a2d065f03af..540ffb8b4df0 100644
--- a/quaint/src/connector/mysql/native/conversion.rs
+++ b/quaint/src/connector/mysql/native/conversion.rs
@@ -68,6 +68,10 @@ pub fn conv_params(params: &[Value<'_>]) -> crate::Result<mysql_async::Params> {
                    dt.timestamp_subsec_micros(),
                )
            }),
+
+            ValueType::Var(name, _) => {
+                Err(Error::builder(ErrorKind::RanQueryWithVarParam(name.clone().into_owned())).build())?
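+                // The `?` converts and returns the error immediately: a query that still
+                // contains unbound variables must never reach the MySQL driver.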
+            }
        };

        match res {
diff --git a/quaint/src/connector/postgres/native/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs
index a55e6490bd86..c67cc8384b63 100644
--- a/quaint/src/connector/postgres/native/conversion.rs
+++ b/quaint/src/connector/postgres/native/conversion.rs
@@ -1,7 +1,7 @@
 mod decimal;

 use crate::{
-    ast::{Value, ValueType},
+    ast::{Value, ValueType, VarType},
    connector::queryable::{GetRow, ToColumnNames},
    error::{Error, ErrorKind},
    prelude::EnumVariant,
@@ -40,23 +40,27 @@ pub(crate) fn params_to_types(params: &[Value<'_>]) -> Vec<PostgresType> {
            }

            match &p.typed {
-                ValueType::Int32(_) => PostgresType::INT4,
-                ValueType::Int64(_) => PostgresType::INT8,
-                ValueType::Float(_) => PostgresType::FLOAT4,
-                ValueType::Double(_) => PostgresType::FLOAT8,
-                ValueType::Text(_) => PostgresType::TEXT,
+                ValueType::Int32(_) | ValueType::Var(_, VarType::Int32) => PostgresType::INT4,
+                ValueType::Int64(_) | ValueType::Var(_, VarType::Int64) => PostgresType::INT8,
+                ValueType::Float(_) | ValueType::Var(_, VarType::Float) => PostgresType::FLOAT4,
+                ValueType::Double(_) | ValueType::Var(_, VarType::Double) => PostgresType::FLOAT8,
+                ValueType::Text(_) | ValueType::Var(_, VarType::Text) => PostgresType::TEXT,
                // Enums are user-defined types, we can't statically infer them, so we let PG infer it
-                ValueType::Enum(_, _) | ValueType::EnumArray(_, _) => PostgresType::UNKNOWN,
-                ValueType::Bytes(_) => PostgresType::BYTEA,
-                ValueType::Boolean(_) => PostgresType::BOOL,
-                ValueType::Char(_) => PostgresType::CHAR,
-                ValueType::Numeric(_) => PostgresType::NUMERIC,
-                ValueType::Json(_) => PostgresType::JSONB,
-                ValueType::Xml(_) => PostgresType::XML,
-                ValueType::Uuid(_) => PostgresType::UUID,
-                ValueType::DateTime(_) => PostgresType::TIMESTAMPTZ,
-                ValueType::Date(_) => PostgresType::TIMESTAMP,
-                ValueType::Time(_) => PostgresType::TIME,
+                ValueType::Enum(_, _) | ValueType::EnumArray(_, _) | ValueType::Var(_, VarType::Enum) => {
+                    PostgresType::UNKNOWN
+                }
+                ValueType::Bytes(_) | ValueType::Var(_, VarType::Bytes) => PostgresType::BYTEA,
+                ValueType::Boolean(_) | ValueType::Var(_, VarType::Boolean) => PostgresType::BOOL,
+                ValueType::Char(_) | ValueType::Var(_, VarType::Char) => PostgresType::CHAR,
+                ValueType::Numeric(_) | ValueType::Var(_, VarType::Numeric) => PostgresType::NUMERIC,
+                ValueType::Json(_) | ValueType::Var(_, VarType::Json) => PostgresType::JSONB,
+                ValueType::Xml(_) | ValueType::Var(_, VarType::Xml) => PostgresType::XML,
+                ValueType::Uuid(_) | ValueType::Var(_, VarType::Uuid) => PostgresType::UUID,
+                ValueType::DateTime(_) | ValueType::Var(_, VarType::DateTime) => PostgresType::TIMESTAMPTZ,
+                ValueType::Date(_) | ValueType::Var(_, VarType::Date) => PostgresType::TIMESTAMP,
+                ValueType::Time(_) | ValueType::Var(_, VarType::Time) => PostgresType::TIME,
+                ValueType::Var(_, VarType::Unknown) => PostgresType::UNKNOWN,
+
                ValueType::Array(ref arr) => {
                    let arr = arr.as_ref().unwrap();
@@ -76,27 +80,53 @@ pub(crate) fn params_to_types(params: &[Value<'_>]) -> Vec<PostgresType> {
                    }

                    match first.typed {
-                        ValueType::Int32(_) => PostgresType::INT4_ARRAY,
-                        ValueType::Int64(_) => PostgresType::INT8_ARRAY,
-                        ValueType::Float(_) => PostgresType::FLOAT4_ARRAY,
-                        ValueType::Double(_) => PostgresType::FLOAT8_ARRAY,
-                        ValueType::Text(_) => PostgresType::TEXT_ARRAY,
+                        ValueType::Int32(_) | ValueType::Var(_, VarType::Int32) => PostgresType::INT4_ARRAY,
+                        ValueType::Int64(_) | ValueType::Var(_, VarType::Int64) => PostgresType::INT8_ARRAY,
+                        ValueType::Float(_) | ValueType::Var(_, VarType::Float) =>
PostgresType::FLOAT4_ARRAY,
+                        ValueType::Double(_) | ValueType::Var(_, VarType::Double) => PostgresType::FLOAT8_ARRAY,
+                        ValueType::Text(_) | ValueType::Var(_, VarType::Text) => PostgresType::TEXT_ARRAY,
                        // Enums are special types, we can't statically infer them, so we let PG infer it
-                        ValueType::Enum(_, _) | ValueType::EnumArray(_, _) => PostgresType::UNKNOWN,
-                        ValueType::Bytes(_) => PostgresType::BYTEA_ARRAY,
-                        ValueType::Boolean(_) => PostgresType::BOOL_ARRAY,
-                        ValueType::Char(_) => PostgresType::CHAR_ARRAY,
-                        ValueType::Numeric(_) => PostgresType::NUMERIC_ARRAY,
-                        ValueType::Json(_) => PostgresType::JSONB_ARRAY,
-                        ValueType::Xml(_) => PostgresType::XML_ARRAY,
-                        ValueType::Uuid(_) => PostgresType::UUID_ARRAY,
-                        ValueType::DateTime(_) => PostgresType::TIMESTAMPTZ_ARRAY,
-                        ValueType::Date(_) => PostgresType::TIMESTAMP_ARRAY,
-                        ValueType::Time(_) => PostgresType::TIME_ARRAY,
+                        ValueType::Enum(_, _) | ValueType::EnumArray(_, _) | ValueType::Var(_, VarType::Enum) => {
+                            PostgresType::UNKNOWN
+                        }
+                        ValueType::Bytes(_) | ValueType::Var(_, VarType::Bytes) => PostgresType::BYTEA_ARRAY,
+                        ValueType::Boolean(_) | ValueType::Var(_, VarType::Boolean) => PostgresType::BOOL_ARRAY,
+                        ValueType::Char(_) | ValueType::Var(_, VarType::Char) => PostgresType::CHAR_ARRAY,
+                        ValueType::Numeric(_) | ValueType::Var(_, VarType::Numeric) => PostgresType::NUMERIC_ARRAY,
+                        ValueType::Json(_) | ValueType::Var(_, VarType::Json) => PostgresType::JSONB_ARRAY,
+                        ValueType::Xml(_) | ValueType::Var(_, VarType::Xml) => PostgresType::XML_ARRAY,
+                        ValueType::Uuid(_) | ValueType::Var(_, VarType::Uuid) => PostgresType::UUID_ARRAY,
+                        ValueType::DateTime(_) | ValueType::Var(_, VarType::DateTime) => {
+                            PostgresType::TIMESTAMPTZ_ARRAY
+                        }
+                        ValueType::Date(_) | ValueType::Var(_, VarType::Date) => PostgresType::TIMESTAMP_ARRAY,
+                        ValueType::Time(_) | ValueType::Var(_, VarType::Time) => PostgresType::TIME_ARRAY,
                        // In the case of nested arrays, we let PG infer the type
-                        ValueType::Array(_) => PostgresType::UNKNOWN,
+                        ValueType::Array(_) | ValueType::Var(_, VarType::Array(_)) => PostgresType::UNKNOWN,
+                        ValueType::Var(_, VarType::Unknown) => PostgresType::UNKNOWN,
                    }
                }
+
+                ValueType::Var(_, VarType::Array(t)) => match &**t {
+                    VarType::Unknown => PostgresType::UNKNOWN,
+                    VarType::Int32 => PostgresType::INT4_ARRAY,
+                    VarType::Int64 => PostgresType::INT8_ARRAY,
+                    VarType::Float => PostgresType::FLOAT4_ARRAY,
+                    VarType::Double => PostgresType::FLOAT8_ARRAY,
+                    VarType::Text => PostgresType::TEXT_ARRAY,
+                    VarType::Enum => PostgresType::UNKNOWN,
+                    VarType::Bytes => PostgresType::BYTEA_ARRAY,
+                    VarType::Boolean => PostgresType::BOOL_ARRAY,
+                    VarType::Char => PostgresType::CHAR_ARRAY,
+                    VarType::Array(_) => PostgresType::UNKNOWN,
+                    VarType::Numeric => PostgresType::NUMERIC_ARRAY,
+                    VarType::Json => PostgresType::JSONB_ARRAY,
+                    VarType::Xml => PostgresType::XML_ARRAY,
+                    VarType::Uuid => PostgresType::UUID_ARRAY,
+                    VarType::DateTime => PostgresType::TIMESTAMPTZ_ARRAY,
+                    VarType::Date => PostgresType::TIMESTAMP_ARRAY,
+                    VarType::Time => PostgresType::TIME_ARRAY,
+                },
            }
        })
        .collect()
@@ -975,6 +1005,11 @@ impl ToSql for Value<'_> {
                Ok(result)
            }),
            (ValueType::DateTime(value), _) => value.map(|value| value.naive_utc().to_sql(ty, out)),
+            (ValueType::Var(name, _), _) => {
+                let error: Box<dyn std::error::Error + Send + Sync> =
+                    Box::new(Error::builder(ErrorKind::RanQueryWithVarParam(name.clone().into_owned())).build());
+                Some(Err(error))
+            }
        };

        match res {
diff --git a/quaint/src/connector/sqlite/native/conversion.rs
index
e24379a58aca..dd6316d5b327 100644 --- a/quaint/src/connector/sqlite/native/conversion.rs +++ b/quaint/src/connector/sqlite/native/conversion.rs @@ -307,6 +307,10 @@ impl ToSql for Value<'_> { date.and_hms_opt(time.hour(), time.minute(), time.second()) }) .map(|dt| ToSqlOutput::from(dt.and_utc().timestamp_millis())), + + ValueType::Var(name, _) => Err(RusqlError::ToSqlConversionFailure(Box::new( + Error::builder(ErrorKind::RanQueryWithVarParam(name.clone().into_owned())).build(), + )))?, }; match value { diff --git a/quaint/src/error/mod.rs b/quaint/src/error/mod.rs index 661eb4d344ff..67f7e62650f8 100644 --- a/quaint/src/error/mod.rs +++ b/quaint/src/error/mod.rs @@ -241,6 +241,12 @@ pub enum ErrorKind { #[error("External error id#{}", _0)] ExternalError(i32), + + #[error("Variable '{0}' used as raw value in query. Variables must be used as parameters.")] + VarAsRawValue(String), + + #[error("Attempted to execute a query that contains unbound variable '{0}' in parameters.")] + RanQueryWithVarParam(String), } #[cfg(not(target_arch = "wasm32"))] diff --git a/quaint/src/visitor/mssql.rs b/quaint/src/visitor/mssql.rs index a3887b4cfaee..56d81232a1e3 100644 --- a/quaint/src/visitor/mssql.rs +++ b/quaint/src/visitor/mssql.rs @@ -402,6 +402,10 @@ impl<'a> Visitor<'a> for Mssql<'a> { // Style 3 is keep all whitespace + internal DTD processing: // https://docs.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?redirectedfrom=MSDN&view=sql-server-ver15#xml-styles ValueType::Xml(cow) => cow.map(|cow| self.write(format!("CONVERT(XML, N'{cow}', 3)"))), + + ValueType::Var(name, _) => Some(Err( + Error::builder(ErrorKind::VarAsRawValue(name.clone().into_owned())).build() + )), }; match res { diff --git a/quaint/src/visitor/mysql.rs b/quaint/src/visitor/mysql.rs index 77979b8b7f64..fd3a10bd3b39 100644 --- a/quaint/src/visitor/mysql.rs +++ b/quaint/src/visitor/mysql.rs @@ -193,6 +193,10 @@ impl<'a> Visitor<'a> for Mysql<'a> { ValueType::Date(date) => date.map(|date| self.write(format!("'{date}'"))), ValueType::Time(time) => time.map(|time| self.write(format!("'{time}'"))), ValueType::Xml(cow) => cow.as_ref().map(|cow| self.write(format!("'{cow}'"))), + + ValueType::Var(name, _) => Some(Err( + Error::builder(ErrorKind::VarAsRawValue(name.clone().into_owned())).build() + )), }; match res { diff --git a/quaint/src/visitor/postgres.rs b/quaint/src/visitor/postgres.rs index c119ab319ded..196eb01ebd7d 100644 --- a/quaint/src/visitor/postgres.rs +++ b/quaint/src/visitor/postgres.rs @@ -1,5 +1,6 @@ use crate::{ ast::*, + error::{Error, ErrorKind}, visitor::{self, Visitor}, }; use itertools::Itertools; @@ -257,6 +258,10 @@ impl<'a> Visitor<'a> for Postgres<'a> { ValueType::DateTime(dt) => dt.map(|dt| self.write(format!("'{}'", dt.to_rfc3339(),))), ValueType::Date(date) => date.map(|date| self.write(format!("'{date}'"))), ValueType::Time(time) => time.map(|time| self.write(format!("'{time}'"))), + + ValueType::Var(name, _) => Some(Err( + Error::builder(ErrorKind::VarAsRawValue(name.clone().into_owned())).build() + )), }; match res { diff --git a/quaint/src/visitor/sqlite.rs b/quaint/src/visitor/sqlite.rs index 7292aa2eca64..d6d70739b5f6 100644 --- a/quaint/src/visitor/sqlite.rs +++ b/quaint/src/visitor/sqlite.rs @@ -140,6 +140,10 @@ impl<'a> Visitor<'a> for Sqlite<'a> { ValueType::Date(date) => date.map(|date| self.write(format!("'{date}'"))), ValueType::Time(time) => time.map(|time| self.write(format!("'{time}'"))), ValueType::Xml(cow) => cow.as_ref().map(|cow| self.write(format!("'{cow}'"))), + 
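+            // Raw SQL rendering cannot inline an unbound variable, so fail with `VarAsRawValue`.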
+            ValueType::Var(name, _) => Some(Err(
+                Error::builder(ErrorKind::VarAsRawValue(name.clone().into_owned())).build()
+            )),
        };

        match res {
diff --git a/query-engine/connectors/sql-query-connector/src/context.rs b/query-engine/connectors/sql-query-connector/src/context.rs
index 3f90e94a027d..b3e28c8152c1 100644
--- a/query-engine/connectors/sql-query-connector/src/context.rs
+++ b/query-engine/connectors/sql-query-connector/src/context.rs
@@ -1,7 +1,7 @@
 use quaint::prelude::ConnectionInfo;
 use telemetry::TraceParent;

-pub(super) struct Context<'a> {
+pub struct Context<'a> {
    connection_info: &'a ConnectionInfo,
    pub(crate) traceparent: Option<TraceParent>,
    /// Maximum rows allowed at once for an insert query.
@@ -13,7 +13,7 @@
 }

 impl<'a> Context<'a> {
-    pub(crate) fn new(connection_info: &'a ConnectionInfo, traceparent: Option<TraceParent>) -> Self {
+    pub fn new(connection_info: &'a ConnectionInfo, traceparent: Option<TraceParent>) -> Self {
        let max_insert_rows = connection_info.max_insert_rows();
        let max_bind_values = connection_info.max_bind_values();
diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs
index 137bff50ca58..75e0bda84bde 100644
--- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs
+++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs
@@ -207,7 +207,7 @@ fn collect_affected_fields(args: &[WriteArgs], model: &Model) -> HashSet<ScalarFieldRef>
-pub(crate) fn generate_insert_statements(
+pub fn generate_insert_statements(
    model: &Model,
    args: Vec<WriteArgs>,
    skip_duplicates: bool,
diff --git a/query-engine/connectors/sql-query-connector/src/error.rs b/query-engine/connectors/sql-query-connector/src/error.rs
index f3e76d84fd9e..970fb173804f 100644
--- a/query-engine/connectors/sql-query-connector/src/error.rs
+++ b/query-engine/connectors/sql-query-connector/src/error.rs
@@ -367,6 +367,8 @@ impl From<QuaintKind> for SqlError {
            e @ QuaintKind::DatabaseAlreadyExists { .. } => SqlError::ConnectionError(e),
            e @ QuaintKind::InvalidConnectionArguments => SqlError::ConnectionError(e),
            e @ QuaintKind::SocketTimeout => SqlError::ConnectionError(e),
+            e @ QuaintKind::VarAsRawValue { .. } => SqlError::ConversionError(e.into()),
+            e @ QuaintKind::RanQueryWithVarParam { ..
} => SqlError::ConversionError(e.into()),
        }
    }
 }
diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs
index 9bd6c2d7f211..680bd4134186 100644
--- a/query-engine/connectors/sql-query-connector/src/lib.rs
+++ b/query-engine/connectors/sql-query-connector/src/lib.rs
@@ -2,17 +2,17 @@
 #![deny(unsafe_code)]

 mod column_metadata;
-mod context;
+pub mod context;
 mod cursor_condition;
 mod database;
 mod error;
 mod filter;
 mod join_utils;
-mod model_extensions;
+pub mod model_extensions;
 mod nested_aggregations;
 mod ordering;
-mod query_arguments_ext;
-mod query_builder;
+pub mod query_arguments_ext;
+pub mod query_builder;
 mod query_ext;
 mod row;
 mod ser_raw;
@@ -22,6 +22,8 @@
 use self::{column_metadata::*, context::Context, query_ext::QueryExt, row::*};
 use quaint::prelude::Queryable;

+pub use database::operations::write::generate_insert_statements;
+
 pub use database::FromSource;
 #[cfg(feature = "driver-adapters")]
 pub use database::Js;
diff --git a/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs b/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs
index 1ee4c358b0d2..42557aa01b3e 100644
--- a/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs
+++ b/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs
@@ -32,7 +32,7 @@ impl From<Vec<Column<'static>>> for ColumnIterator {
    }
 }

-pub(crate) trait AsColumns {
+pub trait AsColumns {
    fn as_columns(&self, ctx: &Context<'_>) -> ColumnIterator;
 }

@@ -48,7 +48,7 @@ impl AsColumns for ModelProjection {
    }
 }

-pub(crate) trait AsColumn {
+pub trait AsColumn {
    fn as_column(&self, ctx: &Context<'_>) -> Column<'static>;
 }
diff --git a/query-engine/connectors/sql-query-connector/src/model_extensions/mod.rs b/query-engine/connectors/sql-query-connector/src/model_extensions/mod.rs
index 66cb072cc7af..d1bff1954100 100644
--- a/query-engine/connectors/sql-query-connector/src/model_extensions/mod.rs
+++ b/query-engine/connectors/sql-query-connector/src/model_extensions/mod.rs
@@ -5,4 +5,5 @@ mod scalar_field;
 mod selection_result;
 mod table;

-pub(crate) use self::{column::*, record::*, relation::*, scalar_field::*, selection_result::*, table::*};
+pub use self::{column::*, record::*, scalar_field::*};
+pub(crate) use self::{relation::*, selection_result::*, table::*};
diff --git a/query-engine/connectors/sql-query-connector/src/model_extensions/scalar_field.rs b/query-engine/connectors/sql-query-connector/src/model_extensions/scalar_field.rs
index 826bc2f1d7e1..a3e88aa1d403 100644
--- a/query-engine/connectors/sql-query-connector/src/model_extensions/scalar_field.rs
+++ b/query-engine/connectors/sql-query-connector/src/model_extensions/scalar_field.rs
@@ -1,8 +1,8 @@
 use crate::context::Context;
 use chrono::Utc;
-use prisma_value::PrismaValue;
+use prisma_value::{PlaceholderType, PrismaValue};
 use quaint::{
-    ast::{EnumName, Value, ValueType},
+    ast::{EnumName, Value, ValueType, VarType},
    prelude::{EnumVariant, TypeDataLength, TypeFamily},
 };
 use query_structure::{ScalarField, TypeIdentifier};
@@ -74,7 +74,21 @@ impl ScalarFieldExt for ScalarField {
                TypeIdentifier::Int => Value::null_int32(),
                TypeIdentifier::BigInt => Value::null_int64(),
                TypeIdentifier::Bytes => Value::null_bytes(),
-                TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach that path"),
+                TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"),
+            },
+            (PrismaValue::Placeholder { name, ..
}, ident) => match ident { + TypeIdentifier::String => Value::var(name, VarType::Text), + TypeIdentifier::Int => Value::var(name, VarType::Int32), + TypeIdentifier::BigInt => Value::var(name, VarType::Int64), + TypeIdentifier::Float => Value::var(name, VarType::Numeric), + TypeIdentifier::Decimal => Value::var(name, VarType::Numeric), + TypeIdentifier::Boolean => Value::var(name, VarType::Boolean), + TypeIdentifier::Enum(_) => Value::var(name, VarType::Enum), + TypeIdentifier::UUID => Value::var(name, VarType::Uuid), + TypeIdentifier::Json => Value::var(name, VarType::Json), + TypeIdentifier::DateTime => Value::var(name, VarType::DateTime), + TypeIdentifier::Bytes => Value::var(name, VarType::Bytes), + TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }, }; @@ -126,6 +140,23 @@ pub fn convert_lossy<'a>(pv: PrismaValue) -> Value<'a> { PrismaValue::Bytes(b) => Value::bytes(b), PrismaValue::Null => Value::null_int32(), // Can't tell which type the null is supposed to be. PrismaValue::Object(_) => unimplemented!(), + PrismaValue::Placeholder { name, r#type } => Value::var(name, convert_placeholder_type_to_var_type(&r#type)), + } +} + +fn convert_placeholder_type_to_var_type(pt: &PlaceholderType) -> VarType { + match pt { + PlaceholderType::Any => VarType::Unknown, + PlaceholderType::String => VarType::Text, + PlaceholderType::Int => VarType::Int32, + PlaceholderType::BigInt => VarType::Int64, + PlaceholderType::Float => VarType::Numeric, + PlaceholderType::Boolean => VarType::Boolean, + PlaceholderType::Decimal => VarType::Numeric, + PlaceholderType::Date => VarType::DateTime, + PlaceholderType::Array(t) => VarType::Array(Box::new(convert_placeholder_type_to_var_type(t))), + PlaceholderType::Object => VarType::Json, + PlaceholderType::Bytes => VarType::Bytes, } } diff --git a/query-engine/connectors/sql-query-connector/src/query_arguments_ext.rs b/query-engine/connectors/sql-query-connector/src/query_arguments_ext.rs index 33db6ff17676..1e2aebb3535e 100644 --- a/query-engine/connectors/sql-query-connector/src/query_arguments_ext.rs +++ b/query-engine/connectors/sql-query-connector/src/query_arguments_ext.rs @@ -1,6 +1,6 @@ use query_structure::QueryArguments; -pub(crate) trait QueryArgumentsExt { +pub trait QueryArgumentsExt { /// If we need to take rows before a cursor position, then we need to reverse the order in SQL. 
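    /// This is the case when `take` is negative (fetching rows *before* the cursor): the SQL
    /// `ORDER BY` is flipped, and the rows are reversed again in memory afterwards (the query
    /// compiler models that second step as `Expression::Reverse`).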
fn needs_reversed_order(&self) -> bool;
diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs b/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs
index 199847a2f340..15d696b4e7ea 100644
--- a/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs
+++ b/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs
@@ -1,7 +1,7 @@
-pub(crate) mod read;
+pub mod read;
 #[cfg(feature = "relation_joins")]
-pub(crate) mod select;
-pub(crate) mod write;
+pub mod select;
+pub mod write;

 use crate::context::Context;
 use crate::model_extensions::SelectionResultExt;
diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs
index 7b1806948688..e33d51857a2f 100644
--- a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs
+++ b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs
@@ -7,7 +7,7 @@ use itertools::Itertools;
 use quaint::ast::*;
 use query_structure::*;

-pub(crate) trait SelectDefinition {
+pub trait SelectDefinition {
    fn into_select<'a>(
        self,
        _: &Model,
@@ -122,7 +122,7 @@ impl SelectDefinition for QueryArguments {
    }
 }

-pub(crate) fn get_records<'a, T>(
+pub fn get_records<'a, T>(
    model: &Model,
    columns: impl Iterator<Item = Column<'static>>,
    virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs
index aa010044b0b9..b250a58d7cee 100644
--- a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs
+++ b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs
@@ -6,7 +6,7 @@ use std::{collections::HashSet, convert::TryInto};

 /// `INSERT` a new record to the database. Resulting an `INSERT` ast and an
 /// optional `RecordProjection` if available from the arguments or model.
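/// (Exported as `pub` so that the experimental query compiler in `query-engine/core` can reuse
/// the same `INSERT` builder; see the `sql-query-connector` HACK note in `core`'s Cargo.toml.)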
-pub(crate) fn create_record(
+pub fn create_record(
    model: &Model,
    mut args: WriteArgs,
    selected_fields: &ModelProjection,
diff --git a/query-engine/connectors/sql-query-connector/src/ser_raw.rs b/query-engine/connectors/sql-query-connector/src/ser_raw.rs
index 87e0e84ae8f9..1bcd059e4c31 100644
--- a/query-engine/connectors/sql-query-connector/src/ser_raw.rs
+++ b/query-engine/connectors/sql-query-connector/src/ser_raw.rs
@@ -187,6 +187,7 @@ impl Serialize for SerializedValue<'_> {
            ValueType::DateTime(value) => value.map(|value| value.to_rfc3339()).serialize(serializer),
            ValueType::Date(value) => value.serialize(serializer),
            ValueType::Time(value) => value.serialize(serializer),
+            ValueType::Var(_, _) => unreachable!(),
        }
    }
 }
diff --git a/query-engine/connectors/sql-query-connector/src/value.rs b/query-engine/connectors/sql-query-connector/src/value.rs
index 2221925e8040..f83746c2beb2 100644
--- a/query-engine/connectors/sql-query-connector/src/value.rs
+++ b/query-engine/connectors/sql-query-connector/src/value.rs
@@ -98,7 +98,35 @@ pub fn to_prisma_value<'a, T: Into<ValueType<'a>>>(qv: T) -> crate::Result<PrismaValue>
        ValueType::Xml(s) => s
            .map(|s| PrismaValue::String(s.into_owned()))
            .unwrap_or(PrismaValue::Null),
+
+        ValueType::Var(name, vt) => PrismaValue::Placeholder {
+            name: name.into_owned(),
+            r#type: var_type_to_prisma_type(&vt),
+        },
    };

    Ok(val)
 }
+
+fn var_type_to_prisma_type(vt: &quaint::ast::VarType) -> prisma_value::PlaceholderType {
+    match vt {
+        quaint::ast::VarType::Unknown => prisma_value::PlaceholderType::Any,
+        quaint::ast::VarType::Int32 => prisma_value::PlaceholderType::Int,
+        quaint::ast::VarType::Int64 => prisma_value::PlaceholderType::BigInt,
+        quaint::ast::VarType::Float => prisma_value::PlaceholderType::Float,
+        quaint::ast::VarType::Double => prisma_value::PlaceholderType::Float,
+        quaint::ast::VarType::Text => prisma_value::PlaceholderType::String,
+        quaint::ast::VarType::Enum => prisma_value::PlaceholderType::String,
+        quaint::ast::VarType::Bytes => prisma_value::PlaceholderType::Bytes,
+        quaint::ast::VarType::Boolean => prisma_value::PlaceholderType::Boolean,
+        quaint::ast::VarType::Char => prisma_value::PlaceholderType::String,
+        quaint::ast::VarType::Array(t) => prisma_value::PlaceholderType::Array(Box::new(var_type_to_prisma_type(t))),
+        quaint::ast::VarType::Numeric => prisma_value::PlaceholderType::Decimal,
+        quaint::ast::VarType::Json => prisma_value::PlaceholderType::Object,
+        quaint::ast::VarType::Xml => prisma_value::PlaceholderType::String,
+        quaint::ast::VarType::Uuid => prisma_value::PlaceholderType::String,
+        quaint::ast::VarType::DateTime => prisma_value::PlaceholderType::Date,
+        quaint::ast::VarType::Date => prisma_value::PlaceholderType::Date,
+        quaint::ast::VarType::Time => prisma_value::PlaceholderType::Date,
+    }
+}
diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml
index 6edb37979c24..d8f97ec2cec1 100644
--- a/query-engine/core/Cargo.toml
+++ b/query-engine/core/Cargo.toml
@@ -41,3 +41,11 @@ telemetry = { path = "../../libs/telemetry" }
 lru = "0.7.7"
 enumflags2.workspace = true
 derive_more.workspace = true
+
+# HACK: query builders need to be a separate crate, and maybe the compiler too
+# HACK: we hardcode PostgreSQL as the dialect for now
+sql-query-connector = { path = "../connectors/sql-query-connector", features = [
+    "postgresql",
+] }
+# HACK: this should not be in core either
+quaint.workspace = true
diff --git a/query-engine/core/src/compiler/expression.rs b/query-engine/core/src/compiler/expression.rs
new file mode 100644
index
000000000000..26e6e066be55
--- /dev/null
+++ b/query-engine/core/src/compiler/expression.rs
@@ -0,0 +1,155 @@
+use query_structure::PrismaValue;
+use serde::Serialize;
+
+#[derive(Debug, Serialize)]
+pub struct Binding {
+    pub name: String,
+    pub expr: Expression,
+}
+
+impl Binding {
+    pub fn new(name: String, expr: Expression) -> Self {
+        Self { name, expr }
+    }
+}
+
+impl std::fmt::Display for Binding {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{} = {}", self.name, self.expr)
+    }
+}
+
+#[derive(Debug, Serialize)]
+pub struct DbQuery {
+    pub query: String,
+    pub params: Vec<PrismaValue>,
+}
+
+impl DbQuery {
+    pub fn new(query: String, params: Vec<PrismaValue>) -> Self {
+        Self { query, params }
+    }
+}
+
+#[derive(Debug, Serialize)]
+#[serde(tag = "type", content = "args")]
+pub enum Expression {
+    /// Sequence of statements. The whole sequence evaluates to the result of the last expression.
+    Seq(Vec<Expression>),
+
+    /// Get binding value.
+    Get { name: String },
+
+    /// A lexical scope with let-bindings.
+    Let {
+        bindings: Vec<Binding>,
+        expr: Box<Expression>,
+    },
+
+    /// Gets the first non-empty value from a list of bindings.
+    GetFirstNonEmpty { names: Vec<String> },
+
+    /// A database query that returns data.
+    Query(DbQuery),
+
+    /// A database query that returns the number of affected rows.
+    Execute(DbQuery),
+
+    /// Reverses the result of an expression in memory.
+    Reverse(Box<Expression>),
+
+    /// Sums a list of scalars returned by the expressions.
+    Sum(Vec<Expression>),
+
+    /// Concatenates a list of lists.
+    Concat(Vec<Expression>),
+}
+
+impl Expression {
+    fn display(&self, f: &mut std::fmt::Formatter<'_>, level: usize) -> std::fmt::Result {
+        let indent = "  ".repeat(level);
+
+        match self {
+            Self::Seq(exprs) => {
+                writeln!(f, "{indent}{{")?;
+                for expr in exprs {
+                    expr.display(f, level + 1)?;
+                    writeln!(f, ";")?;
+                }
+                write!(f, "{indent}}}")?;
+            }
+
+            Self::Get { name } => {
+                write!(f, "{indent}get {name}")?;
+            }
+
+            Self::Let { bindings, expr } => {
+                writeln!(f, "{indent}let")?;
+                for Binding { name, expr } in bindings {
+                    writeln!(f, "{indent}  {name} =")?;
+                    expr.display(f, level + 2)?;
+                    writeln!(f, ";")?;
+                }
+                writeln!(f, "{indent}in")?;
+                expr.display(f, level + 1)?;
+            }
+
+            Self::GetFirstNonEmpty { names } => {
+                write!(f, "{indent}getFirstNonEmpty")?;
+                for name in names {
+                    write!(f, " {}", name)?;
+                }
+            }
+
+            Self::Query(query) => self.display_query("query", query, f, level)?,
+
+            Self::Execute(query) => self.display_query("execute", query, f, level)?,
+
+            Self::Reverse(expr) => {
+                writeln!(f, "{indent}reverse (")?;
+                expr.display(f, level + 1)?;
+                write!(f, "{indent})")?;
+            }
+
+            Self::Sum(exprs) => self.display_function("sum", exprs, f, level)?,
+
+            Self::Concat(exprs) => self.display_function("concat", exprs, f, level)?,
+        }
+
+        Ok(())
+    }
+
+    fn display_query(
+        &self,
+        op: &str,
+        db_query: &DbQuery,
+        f: &mut std::fmt::Formatter<'_>,
+        level: usize,
+    ) -> std::fmt::Result {
+        let indent = "  ".repeat(level);
+        let DbQuery { query, params } = db_query;
+        write!(f, "{indent}{op} {{\n{indent}  {query}\n{indent}}} with {params:?}")
+    }
+
+    fn display_function(
+        &self,
+        name: &str,
+        args: &[Expression],
+        f: &mut std::fmt::Formatter<'_>,
+        level: usize,
+    ) -> std::fmt::Result {
+        let indent = "  ".repeat(level);
+        write!(f, "{indent}{name} (")?;
+        for arg in args {
+            arg.display(f, level + 1)?;
+            writeln!(f, ",")?;
+        }
+        write!(f, ")")
+    }
+}
+
+impl std::fmt::Display for Expression {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.display(f, 0)
+    }
+}
diff --git
a/query-engine/core/src/compiler/mod.rs b/query-engine/core/src/compiler/mod.rs
new file mode 100644
index 000000000000..26170861f259
--- /dev/null
+++ b/query-engine/core/src/compiler/mod.rs
@@ -0,0 +1,29 @@
+pub mod expression;
+pub mod translate;
+
+use std::sync::Arc;
+
+pub use expression::Expression;
+use schema::QuerySchema;
+use thiserror::Error;
+pub use translate::{translate, TranslateError};
+
+use crate::{QueryDocument, QueryGraphBuilder};
+
+#[derive(Debug, Error)]
+pub enum CompileError {
+    #[error("only a single query can be compiled at a time")]
+    UnsupportedRequest,
+
+    #[error("{0}")]
+    TranslateError(#[from] TranslateError),
+}
+
+pub fn compile(query_schema: &Arc<QuerySchema>, query_doc: QueryDocument) -> crate::Result<Expression> {
+    let QueryDocument::Single(query) = query_doc else {
+        return Err(CompileError::UnsupportedRequest.into());
+    };
+
+    let (graph, _serializer) = QueryGraphBuilder::new(query_schema).build(query)?;
+    Ok(translate(graph).map_err(CompileError::from)?)
+}
diff --git a/query-engine/core/src/compiler/translate.rs b/query-engine/core/src/compiler/translate.rs
new file mode 100644
index 000000000000..650d03e936fb
--- /dev/null
+++ b/query-engine/core/src/compiler/translate.rs
@@ -0,0 +1,151 @@
+mod query;
+
+use query::translate_query;
+use thiserror::Error;
+
+use crate::{EdgeRef, Node, NodeRef, Query, QueryGraph};
+
+use super::expression::{Binding, Expression};
+
+#[derive(Debug, Error)]
+pub enum TranslateError {
+    #[error("node {0} has no content")]
+    NodeContentEmpty(String),
+
+    #[error("{0}")]
+    QuaintError(#[from] quaint::error::Error),
+}
+
+pub type TranslateResult<T> = Result<T, TranslateError>;
+
+pub fn translate(mut graph: QueryGraph) -> TranslateResult<Expression> {
+    graph
+        .root_nodes()
+        .into_iter()
+        .map(|node| NodeTranslator::new(&mut graph, node, &[]).translate())
+        .collect::<TranslateResult<Vec<_>>>()
+        .map(Expression::Seq)
+}
+
+struct NodeTranslator<'a, 'b> {
+    graph: &'a mut QueryGraph,
+    node: NodeRef,
+    #[allow(dead_code)]
+    parent_edges: &'b [EdgeRef],
+}
+
+impl<'a, 'b> NodeTranslator<'a, 'b> {
+    fn new(graph: &'a mut QueryGraph, node: NodeRef, parent_edges: &'b [EdgeRef]) -> Self {
+        Self {
+            graph,
+            node,
+            parent_edges,
+        }
+    }
+
+    fn translate(&mut self) -> TranslateResult<Expression> {
+        let node = self
+            .graph
+            .node_content(&self.node)
+            .ok_or_else(|| TranslateError::NodeContentEmpty(self.node.id()))?;
+
+        match node {
+            Node::Query(_) => self.translate_query(),
+            _ => unimplemented!(),
+        }
+    }
+
+    fn translate_query(&mut self) -> TranslateResult<Expression> {
+        self.graph.mark_visited(&self.node);
+
+        let query: Query = self
+            .graph
+            .pluck_node(&self.node)
+            .try_into()
+            .expect("current node must be query");
+
+        translate_query(query)
+    }
+
+    #[allow(dead_code)]
+    fn process_children(&mut self) -> TranslateResult<Vec<Expression>> {
+        let mut child_pairs = self.graph.direct_child_pairs(&self.node);
+
+        // Find the positions of all result returning graph nodes.
+        let mut result_positions = child_pairs
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, (_, child_node))| {
+                if self.graph.subgraph_contains_result(child_node) {
+                    Some(idx)
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+
+        // Start removing the highest indices first to not invalidate subsequent removals.
+        result_positions.sort_unstable();
+        result_positions.reverse();
+
+        let result_subgraphs = result_positions
+            .into_iter()
+            .map(|pos| child_pairs.remove(pos))
+            .collect::<Vec<_>>();
+
+        // Because we split from right to left, everything remaining in `child_pairs`
+        // doesn't belong into results, and is executed before all result scopes.
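+        // The incoming edges are collected up front because `NodeTranslator::new` takes
+        // `&mut self.graph`, and that borrow cannot overlap with reading the edges.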
+        let mut expressions: Vec<Expression> = child_pairs
+            .into_iter()
+            .map(|(_, node)| {
+                let edges = self.graph.incoming_edges(&node);
+                NodeTranslator::new(self.graph, node, &edges).translate()
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        // Fold result scopes into one expression.
+        if !result_subgraphs.is_empty() {
+            let result_exp = self.fold_result_scopes(result_subgraphs)?;
+            expressions.push(result_exp);
+        }
+
+        Ok(expressions)
+    }
+
+    #[allow(dead_code)]
+    fn fold_result_scopes(&mut self, result_subgraphs: Vec<(EdgeRef, NodeRef)>) -> TranslateResult<Expression> {
+        // if the subgraphs all point to the same result node, we fold them in sequence
+        // if not, we can separate them with a getfirstnonempty
+        let bindings = result_subgraphs
+            .into_iter()
+            .map(|(_, node)| {
+                let name = node.id();
+                let edges = self.graph.incoming_edges(&node);
+                let expr = NodeTranslator::new(self.graph, node, &edges).translate()?;
+                Ok(Binding { name, expr })
+            })
+            .collect::<TranslateResult<Vec<_>>>()?;
+
+        let result_nodes = self.graph.result_nodes();
+        let result_binding_names = bindings.iter().map(|b| b.name.clone()).collect::<Vec<_>>();
+
+        if result_nodes.len() == 1 {
+            Ok(Expression::Let {
+                bindings,
+                expr: Box::new(Expression::Get {
+                    name: result_binding_names
+                        .into_iter()
+                        .last()
+                        .expect("no binding for result node"),
+                }),
+            })
+        } else {
+            Ok(Expression::Let {
+                bindings,
+                expr: Box::new(Expression::GetFirstNonEmpty {
+                    names: result_binding_names,
+                }),
+            })
+        }
+    }
+}
diff --git a/query-engine/core/src/compiler/translate/query.rs b/query-engine/core/src/compiler/translate/query.rs
new file mode 100644
index 000000000000..f3ff82c95298
--- /dev/null
+++ b/query-engine/core/src/compiler/translate/query.rs
@@ -0,0 +1,42 @@
+mod convert;
+mod read;
+mod write;
+
+use quaint::{
+    prelude::{ConnectionInfo, ExternalConnectionInfo, SqlFamily},
+    visitor::Visitor,
+};
+use read::translate_read_query;
+use sql_query_connector::context::Context;
+use write::translate_write_query;
+
+use crate::{
+    compiler::expression::{DbQuery, Expression},
+    Query,
+};
+
+use super::TranslateResult;
+
+pub(crate) fn translate_query(query: Query) -> TranslateResult<Expression> {
+    let connection_info = ConnectionInfo::External(ExternalConnectionInfo::new(
+        SqlFamily::Postgres,
+        "public".to_owned(),
+        None,
+    ));
+
+    let ctx = Context::new(&connection_info, None);
+
+    match query {
+        Query::Read(rq) => translate_read_query(rq, &ctx),
+        Query::Write(wq) => translate_write_query(wq, &ctx),
+    }
+}
+
+fn build_db_query<'a>(query: impl Into<quaint::ast::Query<'a>>) -> TranslateResult<DbQuery> {
+    let (sql, params) = quaint::visitor::Postgres::build(query)?;
+    let params = params
+        .into_iter()
+        .map(convert::quaint_value_to_prisma_value)
+        .collect::<Vec<_>>();
+    Ok(DbQuery::new(sql, params))
+}
diff --git a/query-engine/core/src/compiler/translate/query/convert.rs b/query-engine/core/src/compiler/translate/query/convert.rs
new file mode 100644
index 000000000000..2ea8463f93c0
--- /dev/null
+++ b/query-engine/core/src/compiler/translate/query/convert.rs
@@ -0,0 +1,94 @@
+use bigdecimal::{BigDecimal, FromPrimitive};
+use chrono::{DateTime, NaiveDate, Utc};
+use quaint::ast::VarType;
+use query_structure::{PlaceholderType, PrismaValue};
+
+pub(crate) fn quaint_value_to_prisma_value(value: quaint::Value<'_>) -> PrismaValue {
+    match value.typed {
+        quaint::ValueType::Int32(Some(i)) => PrismaValue::Int(i.into()),
+        quaint::ValueType::Int32(None) => PrismaValue::Null,
+        quaint::ValueType::Int64(Some(i)) => PrismaValue::BigInt(i),
+        quaint::ValueType::Int64(None) => PrismaValue::Null,
+        quaint::ValueType::Float(Some(f)) =>
diff --git a/query-engine/core/src/compiler/translate/query/convert.rs b/query-engine/core/src/compiler/translate/query/convert.rs
new file mode 100644
index 000000000000..2ea8463f93c0
--- /dev/null
+++ b/query-engine/core/src/compiler/translate/query/convert.rs
@@ -0,0 +1,94 @@
+use bigdecimal::{BigDecimal, FromPrimitive};
+use chrono::{DateTime, NaiveDate, Utc};
+use quaint::ast::VarType;
+use query_structure::{PlaceholderType, PrismaValue};
+
+pub(crate) fn quaint_value_to_prisma_value(value: quaint::Value<'_>) -> PrismaValue {
+    match value.typed {
+        quaint::ValueType::Int32(Some(i)) => PrismaValue::Int(i.into()),
+        quaint::ValueType::Int32(None) => PrismaValue::Null,
+        quaint::ValueType::Int64(Some(i)) => PrismaValue::BigInt(i),
+        quaint::ValueType::Int64(None) => PrismaValue::Null,
+        quaint::ValueType::Float(Some(f)) => PrismaValue::Float(
+            BigDecimal::from_f32(f)
+                .expect("float to decimal conversion should succeed")
+                .normalized(),
+        ),
+        quaint::ValueType::Float(None) => PrismaValue::Null,
+        quaint::ValueType::Double(Some(d)) => PrismaValue::Float(
+            BigDecimal::from_f64(d)
+                .expect("double to decimal conversion should succeed")
+                .normalized(),
+        ),
+        quaint::ValueType::Double(None) => PrismaValue::Null,
+        quaint::ValueType::Text(Some(s)) => PrismaValue::String(s.into_owned()),
+        quaint::ValueType::Text(None) => PrismaValue::Null,
+        quaint::ValueType::Enum(Some(e), _) => PrismaValue::Enum(e.into_owned()),
+        quaint::ValueType::Enum(None, _) => PrismaValue::Null,
+        quaint::ValueType::EnumArray(Some(es), _) => PrismaValue::List(
+            es.into_iter()
+                .map(|e| e.into_text())
+                .map(quaint_value_to_prisma_value)
+                .collect(),
+        ),
+        quaint::ValueType::EnumArray(None, _) => PrismaValue::Null,
+        quaint::ValueType::Bytes(Some(b)) => PrismaValue::Bytes(b.into_owned()),
+        quaint::ValueType::Bytes(None) => PrismaValue::Null,
+        quaint::ValueType::Boolean(Some(b)) => PrismaValue::Boolean(b),
+        quaint::ValueType::Boolean(None) => PrismaValue::Null,
+        quaint::ValueType::Char(Some(c)) => PrismaValue::String(c.to_string()),
+        quaint::ValueType::Char(None) => PrismaValue::Null,
+        quaint::ValueType::Array(Some(a)) => {
+            PrismaValue::List(a.into_iter().map(quaint_value_to_prisma_value).collect())
+        }
+        quaint::ValueType::Array(None) => PrismaValue::Null,
+        quaint::ValueType::Numeric(Some(bd)) => PrismaValue::Float(bd),
+        quaint::ValueType::Numeric(None) => PrismaValue::Null,
+        quaint::ValueType::Json(Some(j)) => PrismaValue::Json(j.to_string()),
+        quaint::ValueType::Json(None) => PrismaValue::Null,
+        quaint::ValueType::Xml(Some(x)) => PrismaValue::String(x.into_owned()),
+        quaint::ValueType::Xml(None) => PrismaValue::Null,
+        quaint::ValueType::Uuid(Some(u)) => PrismaValue::Uuid(u),
+        quaint::ValueType::Uuid(None) => PrismaValue::Null,
+        quaint::ValueType::DateTime(Some(dt)) => PrismaValue::DateTime(dt.into()),
+        quaint::ValueType::DateTime(None) => PrismaValue::Null,
+        quaint::ValueType::Date(Some(d)) => {
+            let dt = DateTime::<Utc>::from_naive_utc_and_offset(d.and_hms_opt(0, 0, 0).unwrap(), Utc);
+            PrismaValue::DateTime(dt.into())
+        }
+        quaint::ValueType::Date(None) => PrismaValue::Null,
+        quaint::ValueType::Time(Some(t)) => {
+            let d = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
+            let dt = DateTime::<Utc>::from_naive_utc_and_offset(d.and_time(t), Utc);
+            PrismaValue::DateTime(dt.into())
+        }
+        quaint::ValueType::Time(None) => PrismaValue::Null,
+        quaint::ValueType::Var(name, vt) => PrismaValue::Placeholder {
+            name: name.into_owned(),
+            r#type: var_type_to_placeholder_type(&vt),
+        },
+    }
+}
+
+fn var_type_to_placeholder_type(vt: &VarType) -> PlaceholderType {
+    match vt {
+        VarType::Unknown => PlaceholderType::Any,
+        VarType::Int32 => PlaceholderType::Int,
+        VarType::Int64 => PlaceholderType::BigInt,
+        VarType::Float => PlaceholderType::Float,
+        VarType::Double => PlaceholderType::Float,
+        VarType::Text => PlaceholderType::String,
+        VarType::Enum => PlaceholderType::String,
+        VarType::Bytes => PlaceholderType::Bytes,
+        VarType::Boolean => PlaceholderType::Boolean,
+        VarType::Char => PlaceholderType::String,
+        VarType::Array(t) => PlaceholderType::Array(Box::new(var_type_to_placeholder_type(t))),
+        VarType::Numeric => PlaceholderType::Float,
+        VarType::Json => PlaceholderType::Object,
+        VarType::Xml => PlaceholderType::String,
+        VarType::Uuid => PlaceholderType::String,
+        VarType::DateTime => PlaceholderType::Date,
+        VarType::Date => PlaceholderType::Date,
+        VarType::Time => PlaceholderType::Date,
+    }
+}
diff --git a/query-engine/core/src/compiler/translate/query/read.rs b/query-engine/core/src/compiler/translate/query/read.rs
new file mode 100644
index 000000000000..076d4379566a
--- /dev/null
+++ b/query-engine/core/src/compiler/translate/query/read.rs
@@ -0,0 +1,74 @@
+use query_structure::ModelProjection;
+use sql_query_connector::{
+    context::Context, model_extensions::AsColumns, query_arguments_ext::QueryArgumentsExt, query_builder,
+};
+
+use crate::{
+    compiler::{expression::Expression, translate::TranslateResult},
+    ReadQuery, RelatedRecordsQuery,
+};
+
+use super::build_db_query;
+
+pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> TranslateResult<Expression> {
+    Ok(match query {
+        ReadQuery::RecordQuery(rq) => {
+            let selected_fields = rq.selected_fields.without_relations().into_virtuals_last();
+
+            let query = query_builder::read::get_records(
+                &rq.model,
+                ModelProjection::from(&selected_fields)
+                    .as_columns(ctx)
+                    .mark_all_selected(),
+                selected_fields.virtuals(),
+                rq.filter.expect("ReadOne query should always have filter set"),
+                ctx,
+            )
+            .limit(1);
+
+            Expression::Query(build_db_query(query)?)
+        }
+
+        ReadQuery::ManyRecordsQuery(mrq) => {
+            let selected_fields = mrq.selected_fields.without_relations().into_virtuals_last();
+            let needs_reversed_order = mrq.args.needs_reversed_order();
+
+            // TODO: we ignore chunking for now
+            let query = query_builder::read::get_records(
+                &mrq.model,
+                ModelProjection::from(&selected_fields)
+                    .as_columns(ctx)
+                    .mark_all_selected(),
+                selected_fields.virtuals(),
+                mrq.args,
+                ctx,
+            );
+
+            let expr = Expression::Query(build_db_query(query)?);
+
+            if needs_reversed_order {
+                Expression::Reverse(Box::new(expr))
+            } else {
+                expr
+            }
+        }
+
+        ReadQuery::RelatedRecordsQuery(rrq) => {
+            if rrq.parent_field.relation().is_many_to_many() {
+                build_read_m2m_query(rrq, ctx)?
+            } else {
+                build_read_one2m_query(rrq, ctx)?
+            }
+        }
+
+        _ => unimplemented!(),
+    })
+}
+
+fn build_read_m2m_query(_query: RelatedRecordsQuery, _ctx: &Context<'_>) -> TranslateResult<Expression> {
+    todo!()
+}
+
+fn build_read_one2m_query(_query: RelatedRecordsQuery, _ctx: &Context<'_>) -> TranslateResult<Expression> {
+    todo!()
+}
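Editor's aside: the one behavioral contract worth calling out in `convert.rs` is that a `Var` coming back from the query builder survives parameter collection as a placeholder instead of being materialized. A test-style sketch of that contract, assuming the branch's quaint exposes `ValueType::Var` as constructed here and the usual `into_value()` conversion it provides for its other value types:

```rust
#[cfg(test)]
mod convert_sketch {
    use super::convert::quaint_value_to_prisma_value;
    use quaint::ast::VarType;
    use query_structure::{PlaceholderType, PrismaValue};

    #[test]
    fn var_becomes_placeholder() {
        // Hypothetical construction; `into_value()` is assumed to cover Var.
        let value = quaint::ValueType::Var("userEmail".into(), VarType::Text).into_value();

        match quaint_value_to_prisma_value(value) {
            PrismaValue::Placeholder { name, r#type } => {
                assert_eq!(name, "userEmail");
                assert!(matches!(r#type, PlaceholderType::String));
            }
            other => panic!("expected a placeholder, got {other:?}"),
        }
    }
}
```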
diff --git a/query-engine/core/src/compiler/translate/query/write.rs b/query-engine/core/src/compiler/translate/query/write.rs
new file mode 100644
index 000000000000..a3a39f2372da
--- /dev/null
+++ b/query-engine/core/src/compiler/translate/query/write.rs
@@ -0,0 +1,57 @@
+use query_structure::ModelProjection;
+use sql_query_connector::{context::Context, generate_insert_statements, query_builder};
+
+use crate::{
+    compiler::{expression::Expression, translate::TranslateResult},
+    WriteQuery,
+};
+
+use super::build_db_query;
+
+pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> TranslateResult<Expression> {
+    Ok(match query {
+        WriteQuery::CreateRecord(cr) => {
+            // TODO: MySQL needs additional logic to generate IDs on our side.
+            // See sql_query_connector::database::operations::write::create_record
+            let query = query_builder::write::create_record(
+                &cr.model,
+                cr.args,
+                &ModelProjection::from(&cr.selected_fields),
+                ctx,
+            );
+
+            // TODO: we probably need some additional node type or extra info in the WriteQuery node
+            // to help the client executor figure out the returned ID in the case when it's inferred
+            // from the query arguments.
+            Expression::Execute(build_db_query(query)?)
+        }
+
+        WriteQuery::CreateManyRecords(cmr) => {
+            if let Some(selected_fields) = cmr.selected_fields {
+                Expression::Concat(
+                    generate_insert_statements(
+                        &cmr.model,
+                        cmr.args,
+                        cmr.skip_duplicates,
+                        Some(&selected_fields.fields.into()),
+                        ctx,
+                    )
+                    .into_iter()
+                    .map(build_db_query)
+                    .map(|maybe_db_query| maybe_db_query.map(Expression::Execute))
+                    .collect::<TranslateResult<Vec<_>>>()?,
+                )
+            } else {
+                Expression::Sum(
+                    generate_insert_statements(&cmr.model, cmr.args, cmr.skip_duplicates, None, ctx)
+                        .into_iter()
+                        .map(build_db_query)
+                        .map(|maybe_db_query| maybe_db_query.map(Expression::Execute))
+                        .collect::<TranslateResult<Vec<_>>>()?,
+                )
+            }
+        }
+
+        _ => todo!(),
+    })
+}
diff --git a/query-engine/core/src/constants.rs b/query-engine/core/src/constants.rs
index f6a9eb403646..2ec2a7680060 100644
--- a/query-engine/core/src/constants.rs
+++ b/query-engine/core/src/constants.rs
@@ -12,6 +12,7 @@ pub mod custom_types {
     pub const ENUM: &str = "Enum";
     pub const FIELD_REF: &str = "FieldRef";
     pub const RAW: &str = "Raw";
+    pub const PARAM: &str = "Param";
 
     pub fn make_object(typ: &str, value: PrismaValue) -> PrismaValue {
         PrismaValue::Object(vec![make_type_pair(typ), make_value_pair(value)])
diff --git a/query-engine/core/src/error.rs b/query-engine/core/src/error.rs
index b067a325a4a5..e779fc311b3a 100644
--- a/query-engine/core/src/error.rs
+++ b/query-engine/core/src/error.rs
@@ -1,4 +1,4 @@
-use crate::{InterpreterError, QueryGraphBuilderError, RelationViolation, TransactionError};
+use crate::{compiler::CompileError, InterpreterError, QueryGraphBuilderError, RelationViolation, TransactionError};
 use connector::error::ConnectorError;
 use query_structure::DomainError;
 use thiserror::Error;
@@ -67,6 +67,9 @@ pub enum CoreError {
 
     #[error("Query timed out")]
     QueryTimeout,
+
+    #[error("Error compiling a query: {0}")]
+    CompileError(#[from] CompileError),
 }
 
 impl CoreError {
diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs
index 7e1868cc017f..3280660dd458 100644
--- a/query-engine/core/src/lib.rs
+++ b/query-engine/core/src/lib.rs
@@ -3,6 +3,7 @@
 #[macro_use]
 extern crate tracing;
 
+pub mod compiler;
 pub mod constants;
 pub mod executor;
 pub mod protocol;
diff --git a/query-engine/core/src/query_document/parser.rs b/query-engine/core/src/query_document/parser.rs
index 0f512612144d..85ee16e27df7 100644
--- a/query-engine/core/src/query_document/parser.rs
+++ b/query-engine/core/src/query_document/parser.rs
@@ -232,6 +232,11 @@ impl QueryDocumentParser {
         possible_input_types: &[InputType<'a>],
         query_schema: &'a QuerySchema,
     ) -> QueryParserResult<ParsedInputValue<'a>> {
+        // TODO: we disabled generating Param explicitly in the query schema for now
+        if let ArgumentValue::Scalar(pv @ PrismaValue::Placeholder { .. }) = &value {
+            return Ok(ParsedInputValue::Single(pv.clone()));
+        }
+
         let mut failures = Vec::new();
 
         macro_rules! try_this {
@@ -411,6 +416,8 @@ impl QueryDocumentParser {
             // UUID coercion matchers
             (PrismaValue::Uuid(uuid), ScalarType::String) => Ok(PrismaValue::String(uuid.to_string())),
 
+            (pv @ PrismaValue::Placeholder { .. }, ScalarType::Param) => Ok(pv),
+
             // All other combinations are value type mismatches.
             (_, _) => Err(ValidationError::invalid_argument_type(
                 selection_path.segments(),
@@ -908,6 +915,7 @@ pub(crate) mod conversions {
             PrismaValue::Float(_) => "Float".to_string(),
             PrismaValue::BigInt(_) => "BigInt".to_string(),
             PrismaValue::Bytes(_) => "Bytes".to_string(),
+            PrismaValue::Placeholder { r#type, .. } => r#type.to_string(),
         }
     }
diff --git a/query-engine/core/src/query_graph_builder/error.rs b/query-engine/core/src/query_graph_builder/error.rs
index 825b312bbbf5..937b3842d34f 100644
--- a/query-engine/core/src/query_graph_builder/error.rs
+++ b/query-engine/core/src/query_graph_builder/error.rs
@@ -43,6 +43,14 @@ pub enum QueryGraphBuilderError {
     QueryGraphError(QueryGraphError),
 }
 
+impl std::fmt::Display for QueryGraphBuilderError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Debug::fmt(self, f)
+    }
+}
+
+impl std::error::Error for QueryGraphBuilderError {}
+
 #[derive(Debug)]
 pub struct RelationViolation {
     pub(crate) relation_name: String,
diff --git a/query-engine/dmmf/src/ast_builders/datamodel_ast_builder.rs b/query-engine/dmmf/src/ast_builders/datamodel_ast_builder.rs
index 7c202dd962d5..77aac40b78d2 100644
--- a/query-engine/dmmf/src/ast_builders/datamodel_ast_builder.rs
+++ b/query-engine/dmmf/src/ast_builders/datamodel_ast_builder.rs
@@ -335,6 +335,7 @@ fn prisma_value_to_serde(value: &PrismaValue) -> serde_json::Value {
 
             serde_json::Value::Object(map)
         }
+        PrismaValue::Placeholder { .. } => unreachable!(),
     }
 }
 
diff --git a/query-engine/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs b/query-engine/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs
index dd4f26660440..c88e83438a50 100644
--- a/query-engine/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs
+++ b/query-engine/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs
@@ -49,6 +49,7 @@ pub(super) fn render_output_type<'a>(output_type: &OutputType<'a>, ctx: &mut RenderContext)
         ScalarType::UUID => "UUID",
         ScalarType::JsonList => "Json",
         ScalarType::Bytes => "Bytes",
+        ScalarType::Param => unreachable!("output type must not be Param"),
     };
 
     DmmfTypeReference {
diff --git a/query-engine/driver-adapters/src/conversion/js_arg_type.rs b/query-engine/driver-adapters/src/conversion/js_arg_type.rs
index e1ea7c1c5754..63ffafd3b66b 100644
--- a/query-engine/driver-adapters/src/conversion/js_arg_type.rs
+++ b/query-engine/driver-adapters/src/conversion/js_arg_type.rs
@@ -89,5 +89,6 @@ pub fn value_to_js_arg_type(value: &quaint::Value) -> JSArgType {
         quaint::ValueType::DateTime(_) => JSArgType::DateTime,
         quaint::ValueType::Date(_) => JSArgType::Date,
         quaint::ValueType::Time(_) => JSArgType::Time,
+        quaint::ValueType::Var(_, _) => unreachable!(),
     }
 }
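Editor's aside: the `CreateManyRecords` arm above produces one expression per generated `INSERT` and folds them. Isolated as a hypothetical helper (the types come from the new public `compiler::expression` module; the function itself is not part of the diff), the decision looks like this:

```rust
use query_core::compiler::expression::{DbQuery, Expression};

// Sketch of the createMany fold: with a RETURNING clause the row sets of the
// individual INSERTs are concatenated into one result set; without one, the
// affected-row counts are summed.
fn fold_inserts(statements: Vec<DbQuery>, has_returning: bool) -> Expression {
    let executes: Vec<Expression> = statements.into_iter().map(Expression::Execute).collect();

    if has_returning {
        // Each INSERT .. RETURNING contributes rows to the combined result.
        Expression::Concat(executes)
    } else {
        // Each INSERT contributes a count; the client executor adds them up.
        Expression::Sum(executes)
    }
}
```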
diff --git a/query-engine/query-engine-node-api/src/engine.rs b/query-engine/query-engine-node-api/src/engine.rs
index 9e6bf180171f..1d17eb56ff87 100644
--- a/query-engine/query-engine-node-api/src/engine.rs
+++ b/query-engine/query-engine-node-api/src/engine.rs
@@ -354,6 +354,35 @@ impl QueryEngine {
         .await
     }
 
+    #[napi]
+    pub async fn compile(&self, request: String, human_readable: bool) -> napi::Result<String> {
+        let dispatcher = self.logger.dispatcher();
+        let recorder = self.logger.recorder();
+
+        async_panic_to_js_error(async {
+            let inner = self.inner.read().await;
+            let engine = inner.as_engine()?;
+
+            let request = RequestBody::try_from_str(&request, engine.engine_protocol())?;
+            let query_doc = request
+                .into_doc(engine.query_schema())
+                .map_err(|err| napi::Error::from_reason(err.to_string()))?;
+
+            let plan = query_core::compiler::compile(engine.query_schema(), query_doc).map_err(ApiError::from)?;
+
+            let response = if human_readable {
+                plan.to_string()
+            } else {
+                serde_json::to_string(&plan)?
+            };
+
+            Ok(response)
+        })
+        .with_subscriber(dispatcher)
+        .with_optional_recorder(recorder)
+        .await
+    }
+
     /// If connected, attempts to start a transaction in the core and returns its ID.
     #[napi]
     pub async fn start_transaction(&self, input: String, trace: String, request_id: String) -> napi::Result<String> {
diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs
index 5adc474e5cfb..2f76eb2375e6 100644
--- a/query-engine/query-engine-wasm/src/wasm/engine.rs
+++ b/query-engine/query-engine-wasm/src/wasm/engine.rs
@@ -359,4 +359,28 @@ impl QueryEngine {
         .with_subscriber(dispatcher)
         .await
     }
+
+    #[wasm_bindgen]
+    pub async fn compile(
+        &self,
+        request: String,
+        _human_readable: bool, // ignored on wasm to avoid compiling the formatting in
+    ) -> Result<String, wasm_bindgen::JsError> {
+        let dispatcher = self.logger.dispatcher();
+
+        async {
+            let inner = self.inner.read().await;
+            let engine = inner.as_engine()?;
+
+            let request = RequestBody::try_from_str(&request, engine.engine_protocol())?;
+            let query_doc = request
+                .into_doc(engine.query_schema())
+                .map_err(|err| wasm_bindgen::JsError::new(&err.to_string()))?;
+
+            let plan = query_core::compiler::compile(engine.query_schema(), query_doc).map_err(ApiError::from)?;
+            Ok(serde_json::to_string(&plan)?)
+        }
+        .with_subscriber(dispatcher)
+        .await
+    }
 }
diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml
index 439a64f987c9..db011f9238d7 100644
--- a/query-engine/query-engine/Cargo.toml
+++ b/query-engine/query-engine/Cargo.toml
@@ -38,6 +38,7 @@ telemetry = { path = "../../libs/telemetry" }
 serial_test = "*"
 quaint.workspace = true
 indoc.workspace = true
+indexmap.workspace = true
 
 [build-dependencies]
 build-utils.path = "../../libs/build-utils"
diff --git a/query-engine/query-engine/examples/compiler.rs b/query-engine/query-engine/examples/compiler.rs
new file mode 100644
index 000000000000..7a1150cc3651
--- /dev/null
+++ b/query-engine/query-engine/examples/compiler.rs
@@ -0,0 +1,63 @@
+use std::sync::Arc;
+
+use query_core::{query_graph_builder::QueryGraphBuilder, QueryDocument};
+use request_handlers::{JsonBody, JsonSingleQuery, RequestBody};
+use serde_json::json;
+
+pub fn main() -> anyhow::Result<()> {
+    let schema_string = include_str!("./schema.prisma");
+    let schema = psl::validate(schema_string.into());
+
+    if schema.diagnostics.has_errors() {
+        anyhow::bail!("invalid schema");
+    }
+
+    let schema = Arc::new(schema);
+    let query_schema = Arc::new(query_core::schema::build(schema, true));
+
+    // prisma.user.findMany({
+    //     where: {
+    //         email: Prisma.Param("userEmail")
+    //     }
+    // })
+    let query: JsonSingleQuery = serde_json::from_value(json!({
+        "modelName": "User",
+        "action": "findMany",
+        "query": {
+            "arguments": {
+                "where": {
+                    "email": {
+                        "$type": "Param",
+                        "value": "userEmail"
+                    }
+                }
+            },
+            "selection": {
+                "$scalars": true,
+                "posts": {
+                    "arguments": {},
+                    "selection": {
+                        "$scalars": true
+                    }
+                }
+            }
+        }
+    }))?;
+
+    let request = RequestBody::Json(JsonBody::Single(query));
+    let doc = request.into_doc(&query_schema)?;
+
+    let QueryDocument::Single(query) = doc else {
+        anyhow::bail!("expected single query");
+    };
+
+    let (graph, _serializer) = QueryGraphBuilder::new(&query_schema).build(query)?;
+
+    println!("{graph}");
+
+    let expr = query_core::compiler::translate(graph)?;
+
+    println!("{expr}");
+
+    Ok(())
+}
diff --git a/query-engine/query-engine/examples/schema.prisma b/query-engine/query-engine/examples/schema.prisma
new file mode 100644
index 000000000000..ab9cd218da49
--- /dev/null
+++ b/query-engine/query-engine/examples/schema.prisma
@@ -0,0 +1,27 @@
+generator client {
+  provider = "prisma-client-js"
+}
+
+datasource db {
+  provider = "postgresql"
+  url      = "postgresql://postgres:prisma@localhost:5438"
+}
+
+model User {
+  id    String  @id @default(cuid())
+  email String  @unique
+  name  String?
+  posts Post[]
+  val   Int?
+}
+
+model Post {
+  id        String   @id @default(cuid())
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+  published Boolean
+  title     String
+  content   String?
+  authorId  String?
+  author    User?    @relation(fields: [authorId], references: [id])
+}
diff --git a/query-engine/request-handlers/src/protocols/graphql/schema_renderer/type_renderer.rs b/query-engine/request-handlers/src/protocols/graphql/schema_renderer/type_renderer.rs
index 82a70e53dc00..b43449d34cc8 100644
--- a/query-engine/request-handlers/src/protocols/graphql/schema_renderer/type_renderer.rs
+++ b/query-engine/request-handlers/src/protocols/graphql/schema_renderer/type_renderer.rs
@@ -47,6 +47,7 @@ impl<'a> GqlTypeRenderer<'a> {
             ScalarType::UUID => "UUID",
             ScalarType::JsonList => "Json",
             ScalarType::Bytes => "Bytes",
+            ScalarType::Param => "Param",
             ScalarType::Null => unreachable!("Null types should not be picked for GQL rendering."),
         };
 
@@ -86,6 +87,7 @@ impl<'a> GqlTypeRenderer<'a> {
             ScalarType::JsonList => "Json",
             ScalarType::Bytes => "Bytes",
             ScalarType::Null => unreachable!("Null types should not be picked for GQL rendering."),
+            ScalarType::Param => unreachable!("output type must not be Param"),
         };
 
         stringified.to_string()
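Editor's aside: the example above can presumably be run with `cargo run --example compiler` from the `query-engine` crate (invocation assumed; the PR doesn't state it). The only new wire construct it exercises is the `Param` argument. A sketch of that JSON shape and what the protocol adapter in the next hunk turns it into; names mirror the example:

```rust
use serde_json::{json, Value};

// The wire shape of a deferred parameter in the JSON protocol.
fn param_argument() -> Value {
    json!({
        "$type": "Param",
        "value": "userEmail"
    })
}

// JsonProtocolAdapter (below) parses this object into
// ArgumentValue::Scalar(PrismaValue::Placeholder {
//     name: "userEmail".into(),
//     r#type: PlaceholderType::Any,
// })
// so the placeholder flows through query parsing untouched.
```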
diff --git a/query-engine/request-handlers/src/protocols/json/protocol_adapter.rs b/query-engine/request-handlers/src/protocols/json/protocol_adapter.rs
index 650f4e1a8bb9..e30077ecce44 100644
--- a/query-engine/request-handlers/src/protocols/json/protocol_adapter.rs
+++ b/query-engine/request-handlers/src/protocols/json/protocol_adapter.rs
@@ -6,7 +6,7 @@ use query_core::{
     schema::{ObjectType, OutputField, QuerySchema},
     ArgumentValue, Operation, Selection,
 };
-use query_structure::{decode_bytes, parse_datetime, prelude::ParentContainer, Field};
+use query_structure::{decode_bytes, parse_datetime, prelude::ParentContainer, Field, PlaceholderType, PrismaValue};
 use serde_json::Value as JsonValue;
 use std::str::FromStr;
 
@@ -243,6 +243,20 @@ impl<'a> JsonProtocolAdapter<'a> {
 
                     Ok(ArgumentValue::FieldRef(values))
                 }
+                Some(custom_types::PARAM) => {
+                    let name = obj
+                        .get(custom_types::VALUE)
+                        .and_then(|v| v.as_str())
+                        .ok_or_else(build_err)?
+                        .to_owned();
+
+                    let placeholder = PrismaValue::Placeholder {
+                        name,
+                        r#type: PlaceholderType::Any,
+                    };
+
+                    Ok(ArgumentValue::Scalar(placeholder))
+                }
                 _ => {
                     let values = obj
                         .into_iter()
@@ -421,12 +435,12 @@ mod tests {
         generator client {
             provider = "prisma-client-js"
         }
-        
+
         datasource db {
             provider = "mongodb"
            url      = "mongodb://"
         }
-        
+
         model User {
            id   String  @id @map("_id")
            name String?
@@ -441,7 +455,7 @@ mod tests {
         model Post {
            id     String @id @map("_id")
            title  String
-           userId String 
+           userId String
            user   User   @relation(fields: [userId], references: [id])
         }
         "#;
@@ -1391,28 +1405,28 @@ mod tests {
         generator client {
            provider = "prisma-client-js"
         }
-        
+
         datasource db {
            provider = "mongodb"
            url      = "mongodb://"
         }
-        
+
         model Comment {
            id      String @id @default(auto()) @map("_id") @db.ObjectId
-           
+
            country String?
            content CommentContent
         }
-        
+
         type CommentContent {
            text    String
            upvotes CommentContentUpvotes[]
         }
-        
+
         type CommentContentUpvotes {
            vote   Boolean
            userId String
-        }    
+        }
         "#;
 
         let mut schema = psl::validate(schema_str.into());
@@ -1532,21 +1546,21 @@ mod tests {
         generator client {
            provider = "prisma-client-js"
         }
-        
+
         datasource db {
            provider = "mongodb"
            url      = "mongodb://"
         }
-        
+
         model List {
            id   String    @id @default(auto()) @map("_id") @db.ObjectId
            head ListNode?
         }
-        
+
         type ListNode {
            value Int
-           next  ListNode?    
-        }    
+           next  ListNode?
+        }
         "#;
 
         let mut schema = psl::validate(schema_str.into());
@@ -1586,24 +1600,24 @@ mod tests {
         generator client {
            provider = "prisma-client-js"
         }
-        
+
         datasource db {
            provider = "mongodb"
            url      = "mongodb://"
         }
-        
+
         model User {
            id              String @id @default(auto()) @map("_id") @db.ObjectId
-           
+
            billingAddress  Address
            shippingAddress Address
         }
-        
+
         type Address {
            number  Int
            street  String
            zipCode Int
-        }    
+        }
         "#;
 
         let mut schema = psl::validate(schema_str.into());
@@ -1675,28 +1689,28 @@ mod tests {
         generator client {
            provider = "prisma-client-js"
         }
-        
+
         datasource db {
            provider = "mongodb"
            url      = "mongodb://"
         }
-        
+
         model User {
            id              String @id @default(auto()) @map("_id") @db.ObjectId
            billingAddress  Address
            shippingAddress Address
         }
-        
+
         type Address {
            streetAddress StreetAddress
            zipCode       String
            city          String
         }
-        
+
         type StreetAddress {
            streetName  String
            houseNumber String
-        }    
+        }
         "#;
 
         let mut schema = psl::validate(schema_str.into());
diff --git a/query-engine/schema/src/input_types.rs b/query-engine/schema/src/input_types.rs
index 4ce09bd97bd2..3b47b2a37789 100644
--- a/query-engine/schema/src/input_types.rs
+++ b/query-engine/schema/src/input_types.rs
@@ -122,10 +122,15 @@ pub struct InputField<'a> {
 impl<'a> InputField<'a> {
     pub(crate) fn new(
         name: Cow<'a, str>,
-        field_types: Vec<InputType<'a>>,
+        mut field_types: Vec<InputType<'a>>,
         default_value: Option<DefaultKind>,
         is_required: bool,
     ) -> InputField<'a> {
+        // TODO: generating Param input types is disabled for now.
+        #[allow(clippy::overly_complex_bool_expr)]
+        if false && field_types.iter().any(|t| t.is_scalar()) {
+            field_types.push(InputType::Scalar(ScalarType::Param));
+        }
         InputField {
             name,
             default_value,
@@ -279,6 +284,10 @@ impl<'a> InputType<'a> {
         InputType::Enum(containing)
     }
 
+    pub fn is_scalar(&self) -> bool {
+        matches!(self, Self::Scalar(_))
+    }
+
     pub fn is_json(&self) -> bool {
         matches!(
             self,
diff --git a/query-engine/schema/src/query_schema.rs b/query-engine/schema/src/query_schema.rs
index f8fede0ea355..af7885f1a4a4 100644
--- a/query-engine/schema/src/query_schema.rs
+++ b/query-engine/schema/src/query_schema.rs
@@ -370,6 +370,7 @@ pub enum ScalarType {
     JsonList,
     UUID,
     Bytes,
+    Param,
 }
 
 impl fmt::Display for ScalarType {
@@ -387,6 +388,7 @@ impl fmt::Display for ScalarType {
             ScalarType::UUID => "UUID",
             ScalarType::JsonList => "Json",
             ScalarType::Bytes => "Bytes",
+            ScalarType::Param => "Param",
         };
 
         f.write_str(typ)
diff --git a/shell.nix b/shell.nix
index 309f6275660f..bf6c03cc3716 100644
--- a/shell.nix
+++ b/shell.nix
@@ -45,7 +45,9 @@ pkgs.mkShell {
       useLld = "-C link-arg=-fuse-ld=lld";
     in
     pkgs.lib.optionalString pkgs.stdenv.isLinux ''
-      export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="${useLld}"
-      export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_RUSTFLAGS="${useLld}"
+      if [ ! -f .cargo/config.toml ]; then
+        export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="${useLld}"
+        export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_RUSTFLAGS="${useLld}"
+      fi
    '';
 }
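Editor's aside: the `InputField::new` hunk above keeps `Param` generation behind an `if false` gate; the parser hunk earlier works around that by accepting placeholders before input-type matching. What the gate would do once enabled, as a hypothetical in-crate helper (paths within the `schema` crate assumed):

```rust
use crate::{InputType, ScalarType};

// Hypothetical; mirrors the disabled gate in InputField::new. Any field that
// already accepts a scalar would also accept a deferred Param placeholder,
// which the client executor resolves to a concrete value at execution time.
fn extend_with_param<'a>(mut field_types: Vec<InputType<'a>>) -> Vec<InputType<'a>> {
    if field_types.iter().any(|t| t.is_scalar()) {
        field_types.push(InputType::Scalar(ScalarType::Param));
    }
    field_types
}
```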