update cargo.toml in python crate and fix unit test due to hash joins (#483)

* update cargo.toml

* fix group by

* remove unused imports
jimexist authored Jun 3, 2021
1 parent e82d053 commit e713bc3
Showing 5 changed files with 23 additions and 29 deletions.
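The test changes below all stem from one behavior change: with the newly pinned DataFusion revision, which the commit message attributes to hash joins, a grouped query's collect() can return several pyarrow.RecordBatch objects rather than exactly one, so tests that read only collect()[0] had to merge all batches first. A minimal sketch of that merge pattern (the collect_to_pydict helper is hypothetical, not part of this commit):

    import pyarrow as pa

    def collect_to_pydict(batches):
        # merge every collected RecordBatch into a single Table before
        # inspecting columns; a grouped query may emit many batches
        return pa.Table.from_batches(batches).to_pydict()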
2 changes: 1 addition & 1 deletion python/Cargo.toml
@@ -31,7 +31,7 @@
 libc = "0.2"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 rand = "0.7"
 pyo3 = { version = "0.13.2", features = ["extension-module"] }
-datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "2423ff0d" }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "c3fc0c75af5ff2ebb99dba197d9d2ccd83eb5952" }
 
 [lib]
 name = "datafusion"
6 changes: 0 additions & 6 deletions python/tests/generic.py
@@ -15,15 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
-import tempfile
 import datetime
-import os.path
-import shutil
-
 import numpy
 import pyarrow
-import datafusion
 
 # used to write parquet files
 import pyarrow.parquet
24 changes: 9 additions & 15 deletions python/tests/test_df.py
@@ -19,11 +19,11 @@
 
 import pyarrow as pa
 import datafusion
 
 f = datafusion.functions
 
+
 class TestCase(unittest.TestCase):
-
     def _prepare(self):
         ctx = datafusion.ExecutionContext()
 
@@ -51,12 +51,10 @@ def test_select(self):
     def test_filter(self):
         df = self._prepare()
 
-        df = df \
-            .select(
-                f.col("a") + f.col("b"),
-                f.col("a") - f.col("b"),
-            ) \
-            .filter(f.col("a") > f.lit(2))
+        df = df.select(
+            f.col("a") + f.col("b"),
+            f.col("a") - f.col("b"),
+        ).filter(f.col("a") > f.lit(2))
 
         # execute and collect the first (and only) batch
         result = df.collect()[0]
@@ -66,12 +64,10 @@
 
     def test_sort(self):
         df = self._prepare()
-        df = df.sort([
-            f.col("b").sort(ascending=False)
-        ])
+        df = df.sort([f.col("b").sort(ascending=False)])
 
         table = pa.Table.from_batches(df.collect())
-        expected = {'a': [3, 2, 1], 'b': [6, 5, 4]}
+        expected = {"a": [3, 2, 1], "b": [6, 5, 4]}
         self.assertEqual(table.to_pydict(), expected)
 
     def test_limit(self):
@@ -111,10 +107,8 @@ def test_join(self):
         df1 = ctx.create_dataframe([[batch]])
 
         df = df.join(df1, on="a", how="inner")
-        df = df.sort([
-            f.col("a").sort(ascending=True)
-        ])
+        df = df.sort([f.col("a").sort(ascending=True)])
         table = pa.Table.from_batches(df.collect())
 
-        expected = {'a': [1, 2], 'c': [8, 10], 'b': [4, 5]}
+        expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
         self.assertEqual(table.to_pydict(), expected)
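A side note on test_join: it sorts the joined frame before comparing because a hash join does not guarantee output row order. A hedged alternative is sorting on the pyarrow side instead (a sketch assuming a pyarrow version whose compute module supports table-level sort_indices, with df as in the test above):

    import pyarrow as pa
    import pyarrow.compute as pc

    table = pa.Table.from_batches(df.collect())
    # order rows by column "a" before comparing with the expected dict
    table = table.take(pc.sort_indices(table, sort_keys=[("a", "ascending")]))
    assert table.to_pydict() == {"a": [1, 2], "c": [8, 10], "b": [4, 5]}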
12 changes: 9 additions & 3 deletions python/tests/test_sql.py
@@ -82,12 +82,18 @@ def test_execute(self):
         )
 
         # group by
-        result = ctx.sql(
+        results = ctx.sql(
             "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
         ).collect()
 
-        result_keys = result[0].to_pydict()["CAST(a AS Int32)"]
-        result_values = result[0].to_pydict()["COUNT(a)"]
+        # group by returns batches
+        result_keys = []
+        result_values = []
+        for result in results:
+            pydict = result.to_pydict()
+            result_keys.extend(pydict["CAST(a AS Int32)"])
+            result_values.extend(pydict["COUNT(a)"])
+
         result_keys, result_values = (
             list(t) for t in zip(*sorted(zip(result_keys, result_values)))
         )
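The loop above accumulates keys and values across batches before the existing zip/sort normalizes their order. An equivalent merge through a single pyarrow Table (a sketch, not what the commit does; results is the list of batches collected above):

    import pyarrow as pa

    # build one table from all group-by batches, then sort the key/value
    # pairs together, mirroring the zip/sort that follows in the test
    pydict = pa.Table.from_batches(results).to_pydict()
    pairs = sorted(zip(pydict["CAST(a AS Int32)"], pydict["COUNT(a)"]))
    result_keys = [k for k, _ in pairs]
    result_values = [v for _, v in pairs]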
8 changes: 4 additions & 4 deletions python/tests/test_udaf.py
@@ -16,7 +16,6 @@
 # under the License.
 
 import unittest
-
 import pyarrow
 import pyarrow.compute
 import datafusion
@@ -86,6 +85,7 @@ def test_group_by(self):
         df = df.aggregate([f.col("b")], [udaf(f.col("a"))])
 
         # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(result.column(1), pyarrow.array([1.0 + 2.0, 3.0]))
+        batches = df.collect()
+        arrays = [batch.column(1) for batch in batches]
+        joined = pyarrow.concat_arrays(arrays)
+        self.assertEqual(joined, pyarrow.array([1.0 + 2.0, 3.0]))
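Here the fix concatenates the aggregate column across batches, since the grouped result may now span several of them. The same merge can be expressed through a Table (a sketch under the same assumptions, with batches as in the test):

    import pyarrow

    # column(1) of the merged table is a ChunkedArray with one chunk per
    # input batch; flattening its chunks matches pyarrow.concat_arrays above
    table = pyarrow.Table.from_batches(batches)
    joined = pyarrow.concat_arrays(table.column(1).chunks)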
