Skip to content

Commit

Permalink
fix: Support Dict types in in_list physical plans (apache#10031)
Browse files Browse the repository at this point in the history
* fix: Relax type check with dict types in in_list

* refine comments

* fix style, refine comments and address reviewer's comments

* refine comments

* address comments
  • Loading branch information
advancedxy authored Apr 13, 2024
1 parent 694d4b8 commit d698d9d
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 4 deletions.
126 changes: 122 additions & 4 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,18 @@ impl PartialEq<dyn Any> for InListExpr {
}
}

/// Returns `true` when `lhs` and `rhs` are logically the same type.
///
/// A `DataType::Dictionary` is represented by its value type for the purpose
/// of this comparison; its key type is ignored. Only the outermost dictionary
/// layer is looked through, so the value types themselves are compared with
/// ordinary `DataType` equality.
fn is_logically_eq(lhs: &DataType, rhs: &DataType) -> bool {
    // Strips a single level of dictionary encoding, leaving every other type
    // untouched.
    fn logical_type(data_type: &DataType) -> &DataType {
        match data_type {
            DataType::Dictionary(_, value_type) => value_type.as_ref(),
            other => other,
        }
    }

    logical_type(lhs) == logical_type(rhs)
}

/// Creates a unary expression InList
pub fn in_list(
expr: Arc<dyn PhysicalExpr>,
Expand All @@ -426,7 +438,7 @@ pub fn in_list(
let expr_data_type = expr.data_type(schema)?;
for list_expr in list.iter() {
let list_expr_data_type = list_expr.data_type(schema)?;
if !expr_data_type.eq(&list_expr_data_type) {
if !is_logically_eq(&expr_data_type, &list_expr_data_type) {
return internal_err!(
"The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}"
);
Expand Down Expand Up @@ -499,7 +511,21 @@ mod tests {
macro_rules! in_list {
($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap();
in_list_raw!(
$BATCH,
cast_list_exprs,
$NEGATED,
$EXPECTED,
cast_expr,
$SCHEMA
);
}};
}

// applies the in_list expr to an input batch and list without cast
macro_rules! in_list_raw {
($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
let result = expr
.evaluate(&$BATCH)?
.into_array($BATCH.num_rows())
Expand Down Expand Up @@ -540,7 +566,7 @@ mod tests {
&schema
);

// expression: "a not in ("a", "b")"
// expression: "a in ("a", "b", null)"
let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
in_list!(
batch,
Expand All @@ -551,7 +577,7 @@ mod tests {
&schema
);

// expression: "a not in ("a", "b")"
// expression: "a not in ("a", "b", null)"
let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
in_list!(
batch,
Expand Down Expand Up @@ -1314,4 +1340,96 @@ mod tests {

Ok(())
}

/// Exercises `in_list` on a `Dictionary(UInt16, Utf8)` column without casting
/// the list items first: the list is given once as plain Utf8 literals and
/// once as dictionary literals whose key types differ from the column's.
#[test]
fn in_list_utf8_with_dict_types() -> Result<()> {
    // Builds a dictionary-encoded Utf8 literal with the given key type;
    // `None` yields a null Utf8 value inside the dictionary scalar.
    fn dict_utf8(key_type: DataType, value: Option<&str>) -> Arc<dyn PhysicalExpr> {
        let inner = match value {
            Some(s) => ScalarValue::new_utf8(s.to_string()),
            None => ScalarValue::Utf8(None),
        };
        lit(ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)))
    }

    // Single nullable column "a" holding ["a", "d", null].
    let column_type =
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));
    let schema = Schema::new(vec![Field::new("a", column_type, true)]);
    let values: UInt16DictionaryArray =
        vec![Some("a"), Some("d"), None].into_iter().collect();
    let col_a = col("a", &schema)?;
    let batch =
        RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

    // ("a", "b") as plain literals and as dictionary literals.
    let lists_without_null = [
        vec![lit("a"), lit("b")],
        vec![
            dict_utf8(DataType::Int8, Some("a")),
            dict_utf8(DataType::UInt16, Some("b")),
        ],
    ];

    // ("a", "b", null) as plain literals and as dictionary literals.
    let lists_with_null = [
        vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
        vec![
            dict_utf8(DataType::Int8, Some("a")),
            dict_utf8(DataType::UInt16, Some("b")),
            dict_utf8(DataType::UInt16, None),
        ],
    ];

    // Each case: (candidate lists, negated flag, expected result per row).
    let cases: [(&[Vec<Arc<dyn PhysicalExpr>>], bool, Vec<Option<bool>>); 4] = [
        // expression: "a in ("a", "b")"
        (
            &lists_without_null[..],
            false,
            vec![Some(true), Some(false), None],
        ),
        // expression: "a not in ("a", "b")"
        (
            &lists_without_null[..],
            true,
            vec![Some(false), Some(true), None],
        ),
        // expression: "a in ("a", "b", null)"
        (&lists_with_null[..], false, vec![Some(true), None, None]),
        // expression: "a not in ("a", "b", null)"
        (&lists_with_null[..], true, vec![Some(false), None, None]),
    ];

    for (lists, negated, expected) in &cases {
        for list in lists.iter() {
            in_list_raw!(
                batch,
                list.clone(),
                negated,
                expected.clone(),
                col_a.clone(),
                &schema
            );
        }
    }

    Ok(())
}
}
39 changes: 39 additions & 0 deletions datafusion/sqllogictest/test_files/dictionary.slt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ f3 Utf8 YES
f4 Float64 YES
time Timestamp(Nanosecond, None) YES

# in list with dictionary input
query BBB
SELECT
tag_id in ('1000'), '1000' in (tag_id, null), arrow_cast('999','Dictionary(Int32, Utf8)') in (tag_id, null)
FROM m1
----
true true NULL
true true NULL
true true NULL
true true NULL
true true NULL
true true NULL
true true NULL
true true NULL
true true NULL
true true NULL

# Table m2 with a tag columns `tag_id` and `type`, a field column `f5`, and `time`
statement ok
Expand Down Expand Up @@ -165,6 +181,29 @@ order by date_bin('30 minutes', time) DESC
3 400 600 500 2023-12-04T00:30:00
3 100 300 200 2023-12-04T00:00:00

# query with in list
query BBBBBBBB
SELECT
type in ('active', 'passive')
, 'active' in (type)
, 'active' in (type, null)
, arrow_cast('passive','Dictionary(Int8, Utf8)') in (type, null)
, tag_id in ('1000', '2000')
, tag_id in ('999')
, '1000' in (tag_id, null)
, arrow_cast('999','Dictionary(Int16, Utf8)') in (tag_id, null)
FROM m2
----
true true true NULL true false true NULL
true true true NULL true false true NULL
true true true NULL true false true NULL
true true true NULL true false true NULL
true true true NULL true false true NULL
true true true NULL true false true NULL
true false NULL true true false true NULL
true false NULL true true false true NULL
true false NULL true true false true NULL
true false NULL true true false true NULL


# Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
Expand Down

0 comments on commit d698d9d

Please sign in to comment.