feat: Add flatten array function #562

Merged
merged 5 commits on Feb 25, 2024

Conversation

mobley-trent
Contributor

Which issue does this PR close?

Refer to issue #463

Rationale for this change

What changes are included in this PR?

Are there any user-facing changes?

@mobley-trent mobley-trent marked this pull request as ready for review January 16, 2024 11:44
@mobley-trent
Contributor Author

Hello @andygrove, do you mind giving me a hand with this PR? I exposed Flatten in functions.rs, but the Python array function test for flatten is failing like so:

name = 'flatten'

    def __getattr__(name):
>       return getattr(functions, name)
E       AttributeError: module 'functions' has no attribute 'flatten'

@ongchi
Contributor

ongchi commented Feb 6, 2024

Hi @mobley-trent
Did you try to rebuild the package before running pytest? Like this:

# build and install package
maturin develop

Also, don't forget to activate the venv before running this command.

@mobley-trent
Contributor Author

Hey @ongchi, I tested the flatten function and it's failing. Here is the code:

from datafusion import SessionContext, column
from datafusion import functions as f
import numpy as np
import pyarrow as pa


def py_flatten(arr):
    # Testing helper function
    result = []
    for elem in arr:
        if isinstance(elem, list):
            result.extend(py_flatten(elem))
        else:
            result.append(elem)
    return result

ctx = SessionContext()
data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

batch = pa.RecordBatch.from_arrays(
    [np.array(data, dtype=object)], names=["arr"]
)
df = ctx.create_dataframe([[batch]])
col = column("arr")


stmt = f.flatten(col)
py_expr = lambda: [py_flatten(data)]

result = df.select(stmt).collect()[0].column(0).tolist()

print(f"flatten query: {result}")
print(f"py_expr: {py_expr()}")

Results:

>>> flatten query: [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
>>> py_expr: [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]

I expected the flatten query result to be identical to the py_expr result. Is there something I overlooked? Or is this an underlying bug?

@mobley-trent
Contributor Author

mobley-trent commented Feb 10, 2024

Using a regular flatten query:

ctx = SessionContext()
ctx.sql("select flatten([[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]);")

Result:

DataFrame()
+----------------------------------------------------------------------------------------------------------------------------+
| flatten(make_array(make_array(Float64(1),Float64(2),Float64(3)),make_array(Float64(4),Float64(5)),make_array(Float64(6)))) |
+----------------------------------------------------------------------------------------------------------------------------+
| [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]                                                                                             |
+----------------------------------------------------------------------------------------------------------------------------+

@ongchi
Contributor

ongchi commented Feb 11, 2024

Hi @mobley-trent
The df created in the test case may be a bit misleading; it is equivalent to this:

❯ SELECT column1 AS arr FROM (VALUES ([1.0, 2.0, 3.0, 3.0]), ([4.0, 5.0, 3.0]), ([6.0]));
+----------------------+
| arr                  |
+----------------------+
| [1.0, 2.0, 3.0, 3.0] |
| [4.0, 5.0, 3.0]      |
| [6.0]                |
+----------------------+

It consists of multiple rows of one-dimensional array values, so there is no nesting within a row for flatten to remove. For the flatten function, either the existing df should be modified or a new dataframe should be created for this test case.
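To make the row-versus-nesting distinction concrete, here is a minimal pure-Python sketch. It does not call DataFusion; it only models flatten as being applied to each row's value independently, reusing the py_flatten helper from the test snippet above. The names rows_1d, rows_nested, per_row_1d, and per_row_nested are hypothetical and exist only for illustration.

```python
def py_flatten(arr):
    # Same pure-Python helper as in the test snippet earlier in this thread.
    result = []
    for elem in arr:
        if isinstance(elem, list):
            result.extend(py_flatten(elem))
        else:
            result.append(elem)
    return result


data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

# Shape used in the failing test: three rows, each a one-dimensional array.
rows_1d = data

# Hypothetical alternative: a single row whose value is the whole nested array.
rows_nested = [data]

# Model a row-wise function by applying the helper to each row separately.
per_row_1d = [py_flatten(row) for row in rows_1d]
per_row_nested = [py_flatten(row) for row in rows_nested]

print(per_row_1d)      # [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
print(per_row_nested)  # [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]
```

Under this model, the first shape reproduces the unchanged output seen in the failing test, while the second matches the expected py_expr value. A test dataframe for flatten would therefore need a column whose rows hold nested arrays; with pyarrow, something like pa.array([data]) is one possible way to get a single-row nested list column, though that is an assumption and not necessarily what the merged test does.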

Member

@andygrove andygrove left a comment

Thanks @mobley-trent

@mobley-trent
Contributor Author

Fixed the merge conflicts.

@andygrove andygrove merged commit 27a9264 into apache:main Feb 25, 2024
10 checks passed