Improve Test Coverage for Interpolation class in interpol.py #243

Closed · wants to merge 11 commits
1 change: 1 addition & 0 deletions python/requirements.txt
@@ -20,3 +20,4 @@ sphinx-design==0.2.0
sphinx-panels==0.6.0
jsonref==0.2
python-dateutil==2.8.2
+ coverage==6.4.3
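
With coverage now pinned in requirements.txt, a typical local run would look like the following (the paths and the test-file pattern are illustrative, based on the python/tests layout touched by this PR):

    cd python
    coverage run -m unittest discover -s tests -p "*_tests.py"
    coverage report -m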
2 changes: 1 addition & 1 deletion python/tempo/interpol.py
@@ -53,7 +53,7 @@ def __validate_col(
f"Target Column: '{column}' does not exist in DataFrame."
)
if df.select(column).dtypes[0][1] not in supported_target_col_types:
- raise ValueError(
+ raise TypeError(
f"Target Column needs to be one of the following types: {supported_target_col_types}"
)

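For context, a minimal sketch of the behaviour this change produces; the DataFrame below is hypothetical, but the call mirrors the new tests in interpol_tests.py. An unsupported target-column dtype now raises TypeError, while a missing column still raises ValueError:

    from pyspark.sql import SparkSession
    from tempo.interpol import Interpolation

    spark = SparkSession.builder.getOrCreate()
    # a string-typed target column, which is not a supported numeric type
    df = spark.createDataFrame(
        [("A", "2020-01-01 00:00:00", "not-a-number")],
        ["partition_a", "event_ts", "string_target"],
    )

    interp = Interpolation(is_resampled=False)
    try:
        # name-mangled access to the private validator, as the new tests use
        interp._Interpolation__validate_col(
            df, ["partition_a"], ["string_target"], "event_ts"
        )
    except TypeError as err:
        print(err)  # Target Column needs to be one of the following types: [...]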
16 changes: 8 additions & 8 deletions python/tests/as_of_join_tests.py
@@ -22,13 +22,13 @@ def test_asof_join(self):
).df

# joined dataframe should equal the expected dataframe
- self.assertDataFramesEqual(joined_df, dfExpected)
- self.assertDataFramesEqual(non_prefix_joined_df, noRightPrefixdfExpected)
+ self.assertDataFrameEquality(joined_df, dfExpected)
+ self.assertDataFrameEquality(non_prefix_joined_df, noRightPrefixdfExpected)

spark_sql_joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
- self.assertDataFramesEqual(spark_sql_joined_df, dfExpected)
+ self.assertDataFrameEquality(spark_sql_joined_df, dfExpected)

def test_asof_join_skip_nulls_disabled(self):
"""AS-OF Join with skip nulls disabled"""
@@ -47,15 +47,15 @@ def test_asof_join_skip_nulls_disabled(self):
).df

# joined dataframe should equal the expected dataframe with nulls skipped
- self.assertDataFramesEqual(joined_df, dfExpectedSkipNulls)
+ self.assertDataFrameEquality(joined_df, dfExpectedSkipNulls)

# perform the join with skip nulls disabled
joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right", skipNulls=False
).df

# joined dataframe should equal the expected dataframe without nulls skipped
- self.assertDataFramesEqual(joined_df, dfExpectedSkipNullsDisabled)
+ self.assertDataFrameEquality(joined_df, dfExpectedSkipNullsDisabled)

def test_sequence_number_sort(self):
"""Skew AS-OF Join with Partition Window Test"""
@@ -69,7 +69,7 @@ def test_sequence_number_sort(self):
joined_df = tsdf_left.asofJoin(tsdf_right, right_prefix="right").df

# joined dataframe should equal the expected dataframe
- self.assertDataFramesEqual(joined_df, dfExpected)
+ self.assertDataFrameEquality(joined_df, dfExpected)

def test_partitioned_asof_join(self):
"""AS-OF Join with a time-partition"""
@@ -87,7 +87,7 @@ def test_partitioned_asof_join(self):
fraction=0.1,
).df

- self.assertDataFramesEqual(joined_df, dfExpected)
+ self.assertDataFrameEquality(joined_df, dfExpected)

def test_asof_join_nanos(self):
"""As of join with nanosecond timestamps"""
@@ -103,7 +103,7 @@ def test_asof_join_nanos(self):
).df

# compare
- self.assertDataFramesEqual(joined_df, dfExpected)
+ self.assertDataFrameEquality(joined_df, dfExpected)


# MAIN
73 changes: 28 additions & 45 deletions python/tests/base.py
@@ -2,11 +2,15 @@
import os
import unittest
import warnings
+ from typing import Union

import jsonref

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from tempo.tsdf import TSDF
+ from chispa import assert_df_equality
+ from pyspark.sql.dataframe import DataFrame


class SparkTest(unittest.TestCase):
@@ -22,7 +26,7 @@ class SparkTest(unittest.TestCase):
def setUpClass(cls) -> None:
# create and configure PySpark Session
cls.spark = (
SparkSession.builder.appName("myapp")
SparkSession.builder.appName("unit-tests")
.config("spark.jars.packages", "io.delta:delta-core_2.12:1.1.0")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config(
@@ -142,8 +146,8 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds
for tsc in ts_cols:
ts_value = str(df.select(ts_cols).limit(1).collect()[0][0])
ts_pattern = "^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
decimal_pattern = "[.]\d+"
ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
decimal_pattern = r"[.]\d+"
if re.match(ts_pattern, str(ts_value)) is not None:
if (
re.search(decimal_pattern, ts_value) is None
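
The r-prefix matters here: in a normal string literal, sequences like \d are invalid escapes, which Python already flags with a DeprecationWarning and which are slated to become errors in future versions; a raw string hands the backslash through to the regex engine unchanged. A minimal illustration:

    import re

    # identical pattern, but the raw string avoids the invalid-escape warning
    assert re.match(r"^\d{4}-\d{2}-\d{2}", "2022-08-10 00:00:00") is not None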
@@ -153,7 +157,7 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
return df

#
- # DataFrame Assert Functions
+ # Assertion Functions
#

def assertFieldsEqual(self, fieldA, fieldB):
@@ -182,49 +186,28 @@ def assertSchemaContainsField(self, schema, field):
# the attributes of the fields must be equal
self.assertFieldsEqual(field, schema[field.name])

- def assertSchemasEqual(self, schemaA, schemaB):
- """
- Test that the two given schemas are equivalent (column ordering ignored)
- """
- # both schemas must have the same length
- self.assertEqual(len(schemaA.fields), len(schemaB.fields))
- # schemaA must contain every field in schemaB
- for field in schemaB.fields:
- self.assertSchemaContainsField(schemaA, field)
-
- def assertHasSchema(self, df, expectedSchema):
- """
- Test that the given Dataframe conforms to the expected schema
- """
- self.assertSchemasEqual(df.schema, expectedSchema)
-
- def assertDataFramesEqual(self, dfA, dfB):
+ @staticmethod
+ def assertDataFrameEquality(
+ df1: Union[TSDF, DataFrame],
+ df2: Union[TSDF, DataFrame],
+ from_tsdf: bool = False,
+ ignore_row_order: bool = False,
+ ignore_column_order: bool = True,
+ ignore_nullable: bool = True,
+ ):
"""
Test that the two given Dataframes are equivalent.
That is, they have equivalent schemas, and both contain the same values
"""
- # must have the same schemas
- self.assertSchemasEqual(dfA.schema, dfB.schema)
- # enforce a common column ordering
- colOrder = sorted(dfA.columns)
- sortedA = dfA.select(colOrder)
- sortedB = dfB.select(colOrder)
- # must have identical data
- # that is all rows in A must be in B, and vice-versa
- self.assertEqual(
- sortedA.subtract(sortedB).count(),
- 0,
- msg="There are rows in DataFrame A that are not in DataFrame B",
- )
- self.assertEqual(
- sortedB.subtract(sortedA).count(),
- 0,
- msg="There are rows in DataFrame B that are not in DataFrame A",
- )

- def assertTSDFsEqual(self, tsdfA, tsdfB):
- """
- Test that two given TSDFs are equivalent.
- That is, their underlying Dataframes are equivalent.
- """
- self.assertDataFramesEqual(tsdfA.df, tsdfB.df)
+ if from_tsdf:
+ df1 = df1.df
+ df2 = df2.df

+ assert_df_equality(
+ df1,
+ df2,
+ ignore_row_order=ignore_row_order,
+ ignore_column_order=ignore_column_order,
+ ignore_nullable=ignore_nullable,
+ )
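
For reference, a hypothetical call site showing how the single helper replaces both assertDataFramesEqual and assertTSDFsEqual (the fixtures are illustrative): when given TSDFs, from_tsdf=True makes the wrapper unwrap the underlying DataFrames before delegating to chispa's assert_df_equality.

    # plain DataFrames: column order is ignored by default
    self.assertDataFrameEquality(actual_df, expected_df)

    # TSDFs, as the removed assertTSDFsEqual used to handle
    self.assertDataFrameEquality(actual_tsdf, expected_tsdf, from_tsdf=True)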
79 changes: 65 additions & 14 deletions python/tests/interpol_tests.py
@@ -1,6 +1,5 @@
import unittest

- from chispa.dataframe_comparer import *
from pyspark.sql.types import *

from tempo.interpol import Interpolation
@@ -16,6 +15,58 @@ def setUp(self) -> None:
# register interpolation helper
self.interpolate_helper = Interpolation(is_resampled=False)

+ def test_is_resampled_type(self):
+ self.assertIsInstance(self.interpolate_helper.is_resampled, bool)
+
+ def test_validate_fill_method(self):
+ self.assertRaises(
+ ValueError,
+ self.interpolate_helper._Interpolation__validate_fill,
+ "abcd",
+ )
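
These tests reach the private validators through Python's name mangling: inside a class body, a name with two leading underscores is rewritten to _ClassName__name, so __validate_fill is reachable from outside as _Interpolation__validate_fill. A minimal, self-contained illustration (the class here is hypothetical):

    class Interpolator:
        def __validate(self):  # double leading underscore -> name-mangled
            return "ok"

    obj = Interpolator()
    print(obj._Interpolator__validate())  # prints "ok"
    # obj.__validate() would raise AttributeError outside the class body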

+ def test_validate_col_exist_in_df(self):
+ input_df: DataFrame = self.get_data_as_sdf("input_data")
+
+ self.assertRaises(
+ ValueError,
+ self.interpolate_helper._Interpolation__validate_col,
+ input_df,
+ ["partition_a", "does_not_exist"],
+ ["value_a", "value_b"],
+ "event_ts",
+ )
+
+ self.assertRaises(
+ ValueError,
+ self.interpolate_helper._Interpolation__validate_col,
+ input_df,
+ ["partition_a", "partition_b"],
+ ["does_not_exist", "value_b"],
+ "event_ts",
+ )
+
+ self.assertRaises(
+ ValueError,
+ self.interpolate_helper._Interpolation__validate_col,
+ input_df,
+ ["partition_a", "partition_b"],
+ ["value_a", "value_b"],
+ "wrongly_named",
+ )
+
+ def test_validate_col_target_cols_data_type(self):
+ input_df: DataFrame = self.get_data_as_sdf("input_data")
+
+ self.assertRaises(
+ TypeError,
+ self.interpolate_helper._Interpolation__validate_col,
+ input_df,
+ ["partition_a", "partition_b"],
+ ["string_target", "float_target"],
+ "event_ts",
+ )
+
def test_fill_validation(self):
"""Test fill parameter is valid."""

@@ -57,8 +108,8 @@ def test_target_column_validation(self):
method="zero",
show_interpolated=True,
)
- except ValueError as e:
- self.assertEqual(type(e), ValueError)
+ except TypeError as e:
+ self.assertEqual(type(e), TypeError)
else:
self.fail("TypeError not raised")

@@ -132,7 +183,7 @@ def test_zero_fill_interpolation(self):
show_interpolated=True,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_null_fill_interpolation(self):
"""Test null fill interpolation.
@@ -158,7 +209,7 @@ def test_null_fill_interpolation(self):
show_interpolated=True,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_back_fill_interpolation(self):
"""Test back fill interpolation.
@@ -185,7 +236,7 @@ def test_back_fill_interpolation(self):
show_interpolated=True,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_forward_fill_interpolation(self):
"""Test forward fill interpolation.
@@ -212,7 +263,7 @@ def test_forward_fill_interpolation(self):
show_interpolated=True,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_linear_fill_interpolation(self):
"""Test linear fill interpolation.
@@ -239,7 +290,7 @@ def test_linear_fill_interpolation(self):
show_interpolated=True,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_different_freq_abbreviations(self):
"""Test abbreviated frequency values
@@ -264,7 +315,7 @@ def test_different_freq_abbreviations(self):
show_interpolated=True,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_show_interpolated(self):
"""Test linear `show_interpolated` flag
@@ -291,7 +342,7 @@ def test_show_interpolated(self):
show_interpolated=False,
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)


class InterpolationIntegrationTest(SparkTest):
@@ -311,7 +362,7 @@ def test_interpolation_using_default_tsdf_params(self):
).df

# compare with expected
- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_interpolation_using_custom_params(self):
"""Verify that by specifying optional paramters it will change the result of the interpolation based on those modified params."""
@@ -336,7 +387,7 @@ def test_interpolation_using_custom_params(self):
method="linear",
).df

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_tsdf_constructor_params_are_updated(self):
"""Verify that resulting TSDF class has the correct values for ts_col and partition_col based on the interpolation."""
@@ -372,7 +423,7 @@ def test_interpolation_on_sampled_data(self):
.df
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_defaults_with_resampled_df(self):
"""Verify interpolation can be chained with resample within the TSDF class"""
Expand All @@ -388,7 +439,7 @@ def test_defaults_with_resampled_df(self):
.df
)

- assert_df_equality(expected_df, actual_df, ignore_nullable=True)
+ self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)


# MAIN