Skip to content

Commit

Permalink
feat: postgres arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
chesedo authored Jan 17, 2022
1 parent 05b5642 commit 286770f
Show file tree
Hide file tree
Showing 17 changed files with 2,005 additions and 48 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/synth-postgres.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ jobs:
working-directory: synth/testing_harness/postgres
- run: ./e2e.sh test-warning
working-directory: synth/testing_harness/postgres
- run: ./e2e.sh test-arrays
working-directory: synth/testing_harness/postgres
96 changes: 95 additions & 1 deletion core/src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,98 @@ impl Type<MySql> for Value {
}
}

impl Value {
fn to_postgres_string(&self) -> String {
match self {
Self::Array(arr) => {
let (typ, _) = self.get_postgres_type();
let inner = arr
.iter()
.map(|v| v.to_postgres_string())
.collect::<Vec<String>>()
.join(", ");

if typ == "jsonb" {
format!("[{}]", inner)
} else {
format!("{{{}}}", inner)
}
}
Self::Null(_) => "NULL".to_string(),
Self::Bool(b) => b.to_string(),
Self::Number(num) => match num {
Number::F32(f32) => (*f32).to_string(),
Number::F64(f64) => (*f64).to_string(),
_ => num.to_string(),
},
Self::String(str) => format!("\"{}\"", str),
Self::DateTime(date) => date.format_to_string(),
Self::Object(_) => {
serde_json::to_string(&json::synth_val_to_json(self.clone())).unwrap()
}
}
}

pub fn get_postgres_type(&self) -> (&'static str, usize) {
let mut depth = 0;
let mut typ = "";

let mut current = Some(self);

// Based on https://docs.rs/sqlx-core/0.5.9/sqlx_core/postgres/types/index.html
while let Some(c) = current {
let pair = match c {
Value::Null(_) => (None, "unknown"),
Value::Bool(_) => (None, "bool"),
Value::Number(num) => match *num {
Number::I8(_) => (None, "char"),
Number::I16(_) => (None, "int2"),
Number::I32(_) => (None, "int4"),
Number::I64(_) => (None, "int8"),
Number::I128(_) => (None, "numeric"),
Number::U8(_) => (None, "char"),
Number::U16(_) => (None, "int2"),
Number::U32(_) => (None, "int4"),
Number::U64(_) => (None, "int8"),
Number::U128(_) => (None, "numeric"),
Number::F32(_) => (None, "float4"),
Number::F64(_) => (None, "float8"),
},
Value::String(_) => (None, "text"),
Value::DateTime(ChronoValueAndFormat { value, .. }) => match value {
ChronoValue::NaiveDate(_) => (None, "date"),
ChronoValue::NaiveTime(_) => (None, "time"),
ChronoValue::NaiveDateTime(_) => (None, "timestamp"),
ChronoValue::DateTime(_) => (None, "timestamptz"),
},
Value::Object(_) => (None, "jsonb"),
Value::Array(arr) => {
depth += 1;
if arr.is_empty() {
(None, "unknown")
} else {
(Some(&arr[0]), "")
}
}
};

current = pair.0;
typ = pair.1;
}

(typ, depth)
}
}

impl Encode<'_, Postgres> for Value {
fn produces(&self) -> Option<PgTypeInfo> {
// Only arrays needs a special type
match self {
Value::Array(_) => Some(PgTypeInfo::with_name("text")),
_ => None,
}
}

fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
match self {
Value::Null(_) => IsNull::Yes,
Expand Down Expand Up @@ -409,7 +500,10 @@ impl Encode<'_, Postgres> for Value {
json::synth_val_to_json(self.clone()),
buf,
),
Value::Array(arr) => arr.encode_by_ref(buf), //TODO special-case for BYTEA
Value::Array(_) => {
let s = self.to_postgres_string();
<String as Encode<'_, Postgres>>::encode_by_ref(&s, buf)
} //TODO special-case for BYTEA
}
}
}
Expand Down
12 changes: 9 additions & 3 deletions synth/src/datasource/mysql_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl RelationalDataSource for MySqlDataSource {
async fn execute_query(
&self,
query: String,
query_params: Vec<&Value>,
query_params: Vec<Value>,
) -> Result<MySqlQueryResult> {
let mut query = sqlx::query(query.as_str());

Expand Down Expand Up @@ -199,7 +199,13 @@ impl RelationalDataSource for MySqlDataSource {
Ok(content)
}

fn extend_parameterised_query(query: &mut String, _curr_index: usize, extend: usize) {
fn extend_parameterised_query(
query: &mut String,
_curr_index: usize,
query_params: Vec<Value>,
) {
let extend = query_params.len();

query.push('(');
for i in 0..extend {
query.push('?');
Expand Down Expand Up @@ -299,7 +305,7 @@ fn try_match_value(row: &MySqlRow, column: &MySqlColumn) -> Result<Value> {
return Ok(Value::Number(Number::from(truncated)));
}

bail!("Failed to convert Postgresql numeric data type to 64 bit float")
bail!("Failed to convert Mysql numeric data type to 64 bit float")
}
"timestamp" => Value::String(row.try_get::<String, &str>(column.name())?),
"date" => Value::String(format!(
Expand Down
172 changes: 158 additions & 14 deletions synth/src/datasource/postgres_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use std::collections::BTreeMap;
use std::convert::TryFrom;
use synth_core::schema::number_content::{F32, F64, I32, I64};
use synth_core::schema::{
BoolContent, Categorical, ChronoValueType, DateTimeContent, NumberContent, RangeStep,
RegexContent, StringContent, Uuid,
ArrayContent, BoolContent, Categorical, ChronoValue, ChronoValueAndFormat, ChronoValueType,
DateTimeContent, NumberContent, ObjectContent, RangeStep, RegexContent, StringContent, Uuid,
};
use synth_core::{Content, Value};

Expand Down Expand Up @@ -120,7 +120,7 @@ impl RelationalDataSource for PostgresDataSource {
async fn execute_query(
&self,
query: String,
query_params: Vec<&Value>,
query_params: Vec<Value>,
) -> Result<PgQueryResult> {
let mut query = sqlx::query(query.as_str());

Expand Down Expand Up @@ -255,20 +255,20 @@ impl RelationalDataSource for PostgresDataSource {
RegexContent::pattern(pattern).context("pattern will always compile")?,
))
}
"int2" => Content::Number(NumberContent::I64(I64::Range(RangeStep::default()))),
"int2" => Content::Number(NumberContent::I32(I32::Range(RangeStep::default()))),
"int4" => Content::Number(NumberContent::I32(I32::Range(RangeStep::default()))),
"int8" => Content::Number(NumberContent::I64(I64::Range(RangeStep::default()))),
"float4" => Content::Number(NumberContent::F32(F32::Range(RangeStep::default()))),
"float8" => Content::Number(NumberContent::F64(F64::Range(RangeStep::default()))),
"numeric" => Content::Number(NumberContent::F64(F64::Range(RangeStep::default()))),
"timestamptz" => Content::DateTime(DateTimeContent {
format: "".to_string(), // todo
format: "%Y-%m-%dT%H:%M:%S%z".to_string(),
type_: ChronoValueType::DateTime,
begin: None,
end: None,
}),
"timestamp" => Content::DateTime(DateTimeContent {
format: "".to_string(), // todo
format: "%Y-%m-%dT%H:%M:%S".to_string(),
type_: ChronoValueType::NaiveDateTime,
begin: None,
end: None,
Expand All @@ -279,20 +279,56 @@ impl RelationalDataSource for PostgresDataSource {
begin: None,
end: None,
}),
"time" => Content::DateTime(DateTimeContent {
format: "%H:%M:%S".to_string(),
type_: ChronoValueType::NaiveTime,
begin: None,
end: None,
}),
"json" | "jsonb" => Content::Object(ObjectContent {
skip_when_null: false,
fields: BTreeMap::new(),
}),
"uuid" => Content::String(StringContent::Uuid(Uuid)),
_ => bail!(
"We haven't implemented a converter for {}",
column_info.data_type
),
_ => {
if let Some(data_type) = column_info.data_type.strip_prefix('_') {
let mut column_info = column_info.clone();
column_info.data_type = data_type.to_string();

Content::Array(ArrayContent::from_content_default_length(
self.decode_to_content(&column_info)?,
))
} else {
bail!(
"We haven't implemented a converter for {}",
column_info.data_type
)
}
}
};

Ok(content)
}

fn extend_parameterised_query(query: &mut String, curr_index: usize, extend: usize) {
fn extend_parameterised_query(query: &mut String, curr_index: usize, query_params: Vec<Value>) {
let extend = query_params.len();

query.push('(');
for i in 0..extend {
query.push_str(&format!("${}", curr_index + i + 1));
for (i, param) in query_params.iter().enumerate() {
let extra = if let Value::Array(_) = param {
let (typ, depth) = param.get_postgres_type();
if typ == "unknown" {
"".to_string() // This is currently not supported
} else if typ == "jsonb" {
"::jsonb".to_string() // Cannot have an array of jsonb - ie jsonb[]
} else {
format!("::{}{}", typ, "[]".repeat(depth))
}
} else {
"".to_string()
};

query.push_str(&format!("${}{}", curr_index + i + 1, extra));
if i != extend - 1 {
query.push(',');
}
Expand Down Expand Up @@ -347,7 +383,10 @@ impl TryFrom<PgRow> for ValueWrapper {
let mut kv = BTreeMap::new();

for column in row.columns() {
let value = try_match_value(&row, column).unwrap_or(Value::Null(()));
let value = try_match_value(&row, column).unwrap_or_else(|err| {
debug!("try_match_value failed: {}", err);
Value::Null(())
});
kv.insert(column.name().to_string(), value);
}

Expand Down Expand Up @@ -389,6 +428,111 @@ fn try_match_value(row: &PgRow, column: &PgColumn) -> Result<Value> {
"{}",
row.try_get::<chrono::NaiveDate, &str>(column.name())?
)),
"time" => Value::String(format!(
"{}",
row.try_get::<chrono::NaiveTime, &str>(column.name())?
)),
"json" | "jsonb" => {
let serde_value = row.try_get::<serde_json::Value, &str>(column.name())?;
serde_json::from_value(serde_value)?
}
"char[]" | "varchar[]" | "text[]" | "citext[]" | "bpchar[]" | "name[]" | "unknown[]" => {
Value::Array(
row.try_get::<Vec<String>, &str>(column.name())
.map(|vec| vec.iter().map(|s| Value::String(s.to_string())).collect())?,
)
}
"bool[]" => Value::Array(
row.try_get::<Vec<bool>, &str>(column.name())
.map(|vec| vec.into_iter().map(Value::Bool).collect())?,
),
"int2[]" => Value::Array(
row.try_get::<Vec<i16>, &str>(column.name())
.map(|vec| vec.into_iter().map(|i| Value::Number(i.into())).collect())?,
),
"int4[]" => Value::Array(
row.try_get::<Vec<i32>, &str>(column.name())
.map(|vec| vec.into_iter().map(|i| Value::Number(i.into())).collect())?,
),
"int8[]" => Value::Array(
row.try_get::<Vec<i64>, &str>(column.name())
.map(|vec| vec.into_iter().map(|i| Value::Number(i.into())).collect())?,
),
"float4[]" => Value::Array(
row.try_get::<Vec<f32>, &str>(column.name())
.map(|vec| vec.into_iter().map(|i| Value::Number(i.into())).collect())?,
),
"float8[]" => Value::Array(
row.try_get::<Vec<f64>, &str>(column.name())
.map(|vec| vec.into_iter().map(|i| Value::Number(i.into())).collect())?,
),
"numeric[]" => {
let vec = row.try_get::<Vec<Decimal>, &str>(column.name())?;
let result: Result<Vec<Value>, _> = vec
.into_iter()
.map(|d| {
if let Some(truncated) = d.to_f64() {
return Ok(Value::Number(truncated.into()));
}

bail!("Failed to convert Postgresql numeric data type to 64 bit float")
})
.collect();

Value::Array(result?)
}
"timestamp[]" => Value::Array(
row.try_get::<Vec<chrono::NaiveDateTime>, &str>(column.name())
.map(|vec| {
vec.into_iter()
.map(|d| {
Value::DateTime(ChronoValueAndFormat {
format: Arc::from("%Y-%m-%dT%H:%M:%S".to_owned()),
value: ChronoValue::NaiveDateTime(d),
})
})
.collect()
})?,
),
"timestamptz[]" => Value::Array(
row.try_get::<Vec<chrono::DateTime<chrono::FixedOffset>>, &str>(column.name())
.map(|vec| {
vec.into_iter()
.map(|d| {
Value::DateTime(ChronoValueAndFormat {
format: Arc::from("%Y-%m-%dT%H:%M:%S%z".to_owned()),
value: ChronoValue::DateTime(d),
})
})
.collect()
})?,
),
"date[]" => Value::Array(
row.try_get::<Vec<chrono::NaiveDate>, &str>(column.name())
.map(|vec| {
vec.into_iter()
.map(|d| {
Value::DateTime(ChronoValueAndFormat {
format: Arc::from("%Y-%m-%d".to_owned()),
value: ChronoValue::NaiveDate(d),
})
})
.collect()
})?,
),
"time[]" => Value::Array(
row.try_get::<Vec<chrono::NaiveTime>, &str>(column.name())
.map(|vec| {
vec.into_iter()
.map(|t| {
Value::DateTime(ChronoValueAndFormat {
format: Arc::from("%H:%M:%S".to_owned()),
value: ChronoValue::NaiveTime(t),
})
})
.collect()
})?,
),
_ => {
bail!(
"Could not convert value. Converter not implemented for {}",
Expand Down
Loading

0 comments on commit 286770f

Please sign in to comment.