Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add union_tag scalar function #14687

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub mod nvl2;
pub mod planner;
pub mod r#struct;
pub mod union_extract;
pub mod union_tag;
pub mod version;

// create UDFs
Expand All @@ -50,6 +51,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce);
make_udf_function!(greatest::GreatestFunc, greatest);
make_udf_function!(least::LeastFunc, least);
make_udf_function!(union_extract::UnionExtractFun, union_extract);
make_udf_function!(union_tag::UnionTagFunc, union_tag);
make_udf_function!(version::VersionFunc, version);

pub mod expr_fn {
Expand Down Expand Up @@ -95,6 +97,10 @@ pub mod expr_fn {
least,
"Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL",
args,
),(
union_tag,
"Returns the name of the currently selected field in the union",
arg1
));

#[doc = "Returns the value of the field with the given name from the struct"]
Expand Down Expand Up @@ -129,6 +135,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
greatest(),
least(),
union_extract(),
union_tag(),
version(),
r#struct(),
]
Expand Down
223 changes: 223 additions & 0 deletions datafusion/functions/src/core/union_tag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray};
use arrow::datatypes::DataType;
use datafusion_common::utils::take_function_args;
use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use datafusion_doc::Documentation;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::sync::Arc;

#[user_doc(
doc_section(label = "Union Functions"),
description = "Returns the name of the currently selected field in the union",
syntax_example = "union_tag(union_expression)",
sql_example = r#"```sql
❯ select union_column, union_tag(union_column) from table_with_union;
+--------------+-------------------------+
| union_column | union_tag(union_column) |
+--------------+-------------------------+
| {a=1} | a |
| {b=3.0} | b |
| {a=4} | a |
| {b=} | b |
| {a=} | a |
+--------------+-------------------------+
```"#,
standard_argument(name = "union", prefix = "Union")
)]
#[derive(Debug)]
pub struct UnionTagFunc {
signature: Signature,
}

impl Default for UnionTagFunc {
fn default() -> Self {
Self::new()
}
}

impl UnionTagFunc {
pub fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for UnionTagFunc {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"union_tag"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Utf8),
))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [union_] = take_function_args("union_tag", args.args)?;

match union_ {
ColumnarValue::Array(array)
if matches!(array.data_type(), DataType::Union(_, _)) =>
{
let union_array = array.as_union();

let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?;

let fields = match union_array.data_type() {
DataType::Union(fields, _) => fields,
_ => unreachable!(),
};

// Union fields type IDs only constraints are being unique and in the 0..128 range:
// They may not start at 0, be sequential, or even contiguous.
// Therefore, we allocate a values vector with a length equal to the highest type ID plus one,
// ensuring that each field's name can be placed at the index corresponding to its type ID.
Copy link
Contributor Author

@gstvg gstvg Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The union column used on the sqllogictests contains a single field with type id 3, so this is put to the test

fn register_union_table(ctx: &SessionContext) {
let union = UnionArray::try_new(
UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]),
ScalarBuffer::from(vec![3, 3]),
None,
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)
.unwrap();
let schema = Schema::new(vec![Field::new(
"union_column",
union.data_type().clone(),
false,
)]);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap();
ctx.register_batch("union_table", batch).unwrap();
}

"union_function.slt" => {
info!("Registering table with union column");
register_union_table(test_ctx.session_ctx())
}

let values_len = fields
.iter()
.map(|(type_id, _)| type_id + 1)
.max()
.unwrap_or_default() as usize;

let mut values = vec![""; values_len];

for (type_id, field) in fields.iter() {
values[type_id as usize] = field.name().as_str()
}

let values = Arc::new(StringArray::from(values));

// SAFETY: union type_ids are validated to not be smaller than zero.
// values len is the union biggest type id plus one.
// keys is built from the union type_ids, which contains only valid type ids
// therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs
let dict = unsafe { DictionaryArray::new_unchecked(keys, values) };

Ok(ColumnarValue::Array(Arc::new(dict)))
}
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value {
Some((value_type_id, _)) => fields
.iter()
.find(|(type_id, _)| value_type_id == *type_id)
.map(|(_, field)| {
ColumnarValue::Scalar(ScalarValue::Dictionary(
Box::new(DataType::Int8),
Box::new(field.name().as_str().into()),
))
})
.ok_or_else(|| {
exec_datafusion_err!(
"union_tag: union scalar with unknow type_id {value_type_id}"
)
}),
None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
args.return_type,
)?)),
},
v => exec_err!("union_tag only support unions, got {:?}", v.data_type()),
}
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

#[cfg(test)]
mod tests {
use super::UnionTagFunc;
use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
use std::sync::Arc;

// when it becomes possible to construct union scalars in SQL, this should go to sqllogictests
#[test]
fn union_scalar() {
let fields = [(0, Arc::new(Field::new("a", DataType::UInt32, false)))]
.into_iter()
.collect();

let scalar = ScalarValue::Union(
Some((0, Box::new(ScalarValue::UInt32(Some(0))))),
fields,
UnionMode::Dense,
);

let result = UnionTagFunc::new()
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(scalar)],
number_rows: 1,
return_type: &DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Utf8),
),
})
.unwrap();

assert_scalar(
result,
ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new("a".into())),
);
}

#[test]
fn union_scalar_empty() {
let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense);

let result = UnionTagFunc::new()
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(scalar)],
number_rows: 1,
return_type: &DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Utf8),
),
})
.unwrap();

assert_scalar(
result,
ScalarValue::Dictionary(
Box::new(DataType::Int8),
Box::new(ScalarValue::Utf8(None)),
),
);
}

fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
match value {
ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected),
}
}
}
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/union_function.slt
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,18 @@ select union_extract(union_column, 1) from union_table;

query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3
select union_extract(union_column, 'a', 'b') from union_table;

query ?T
select union_column, union_tag(union_column) from union_table;
----
{int=1} int
{int=2} int

query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments
select union_tag() from union_table;

query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2
select union_tag(union_column, 'int') from union_table;

query error DataFusion error: Execution error: union_tag only support unions, got Utf8
select union_tag('int') from union_table;
28 changes: 28 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -4344,6 +4344,7 @@ sha512(expression)
Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator

- [union_extract](#union_extract)
- [union_tag](#union_tag)

### `union_extract`

Expand Down Expand Up @@ -4373,6 +4374,33 @@ union_extract(union, field_name)
+--------------+----------------------------------+----------------------------------+
```

### `union_tag`

Returns the name of the currently selected field in the union

```sql
union_tag(union_expression)
```

#### Arguments

- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators.

#### Example

```sql
❯ select union_column, union_tag(union_column) from table_with_union;
+--------------+-------------------------+
| union_column | union_tag(union_column) |
+--------------+-------------------------+
| {a=1} | a |
| {b=3.0} | b |
| {a=4} | a |
| {b=} | b |
| {a=} | a |
+--------------+-------------------------+
```

## Other Functions

- [arrow_cast](#arrow_cast)
Expand Down
Loading