Skip to content

Commit

Permalink
Add trends test case for 8e2dd3c
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbachhuber committed Feb 4, 2025
1 parent 8e2dd3c commit 2d10c40
Showing 1 changed file with 251 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,162 @@ def create_data_warehouse_table_with_usage(self):
)
return table_name

def create_data_warehouse_table_with_subscriptions(self):
    """Create joined ``subscriptions`` and ``customers`` warehouse tables.

    Writes two small parquet datasets to object storage, registers a
    ``DataWarehouseTable`` for each, and creates two ``DataWarehouseJoin``
    rows:

    * ``subscriptions.subscription_customer_id`` -> ``customers.customer_id``,
      exposed on subscriptions as ``subscription_customer``;
    * ``subscriptions.subscription_customer.customer_email`` ->
      ``events.person.properties.email``, exposed as ``events`` and configured
      with ``subscription_created_at`` as the experiments timestamp key.

    Each dataset is written under its own sub-prefix of the test bucket so
    that a table's ``*.parquet`` glob matches only its own files. Writing
    both datasets to one shared prefix would make each table's glob also
    match the other table's file, which has a different schema.

    Returns:
        The name of the subscriptions table (``"subscriptions"``).

    Raises:
        Exception: If the object storage access key or secret is not set.
    """
    if not OBJECT_STORAGE_ACCESS_KEY_ID or not OBJECT_STORAGE_SECRET_ACCESS_KEY:
        raise Exception("Missing vars")

    fs = s3fs.S3FileSystem(
        client_kwargs={
            "region_name": "us-east-1",
            "endpoint_url": OBJECT_STORAGE_ENDPOINT,
            "aws_access_key_id": OBJECT_STORAGE_ACCESS_KEY_ID,
            "aws_secret_access_key": OBJECT_STORAGE_SECRET_ACCESS_KEY,
        },
    )

    subscription_table_name = "subscriptions"
    customer_table_name = "customers"

    # Shared "<bucket>/<test prefix>" root; each table gets a sub-prefix below it.
    # NOTE(review): assumes fixture cleanup deletes by the TEST_BUCKET prefix,
    # so the nested objects are removed as well - confirm against teardown.
    storage_root = f"{OBJECT_STORAGE_BUCKET}/{TEST_BUCKET}"

    credential = DataWarehouseCredential.objects.create(
        access_key=OBJECT_STORAGE_ACCESS_KEY_ID,
        access_secret=OBJECT_STORAGE_SECRET_ACCESS_KEY,
        team=self.team,
    )

    subscription_table_data = [
        {
            "subscription_id": "1",
            "subscription_created_at": datetime(2023, 1, 2),
            "subscription_customer_id": "user_control_0",
            "subscription_amount": 100,
        },
        {
            "subscription_id": "2",
            "subscription_created_at": datetime(2023, 1, 3),
            "subscription_customer_id": "user_test_1",
            "subscription_amount": 50,
        },
        {
            "subscription_id": "3",
            "subscription_created_at": datetime(2023, 1, 4),
            "subscription_customer_id": "user_test_2",
            "subscription_amount": 75,
        },
        {
            "subscription_id": "4",
            "subscription_created_at": datetime(2023, 1, 5),
            "subscription_customer_id": "user_test_3",
            "subscription_amount": 80,
        },
        {
            "subscription_id": "5",
            "subscription_created_at": datetime(2023, 1, 6),
            "subscription_customer_id": "user_extra",
            "subscription_amount": 90,
        },
    ]

    pq.write_to_dataset(
        pa.Table.from_pylist(subscription_table_data),
        f"s3://{storage_root}/{subscription_table_name}",
        filesystem=fs,
        use_dictionary=True,
        compression="snappy",
    )

    DataWarehouseTable.objects.create(
        name=subscription_table_name,
        url_pattern=f"http://host.docker.internal:19000/{storage_root}/{subscription_table_name}/*.parquet",
        format=DataWarehouseTable.TableFormat.Parquet,
        team=self.team,
        columns={
            "subscription_id": "String",
            "subscription_created_at": "DateTime64(3, 'UTC')",
            "subscription_customer_id": "String",
            "subscription_amount": "Int64",
        },
        credential=credential,
    )

    customer_table_data = [
        {
            "customer_id": "user_control_0",
            "customer_created_at": datetime(2023, 1, 1),
            "customer_name": "John Doe",
            "customer_email": "john.doe@example.com",
        },
        {
            "customer_id": "user_test_1",
            "customer_created_at": datetime(2023, 1, 2),
            "customer_name": "Jane Doe",
            "customer_email": "jane.doe@example.com",
        },
        {
            "customer_id": "user_test_2",
            "customer_created_at": datetime(2023, 1, 3),
            "customer_name": "John Smith",
            "customer_email": "john.smith@example.com",
        },
        {
            "customer_id": "user_test_3",
            "customer_created_at": datetime(2023, 1, 6),
            "customer_name": "Jane Smith",
            "customer_email": "jane.smith@example.com",
        },
        {
            "customer_id": "user_extra",
            "customer_created_at": datetime(2023, 1, 7),
            "customer_name": "John Doe Jr",
            "customer_email": "john.doejr@example.com",
        },
    ]

    pq.write_to_dataset(
        pa.Table.from_pylist(customer_table_data),
        f"s3://{storage_root}/{customer_table_name}",
        filesystem=fs,
        use_dictionary=True,
        compression="snappy",
    )

    DataWarehouseTable.objects.create(
        name=customer_table_name,
        url_pattern=f"http://host.docker.internal:19000/{storage_root}/{customer_table_name}/*.parquet",
        format=DataWarehouseTable.TableFormat.Parquet,
        team=self.team,
        columns={
            "customer_id": "String",
            "customer_created_at": "DateTime64(3, 'UTC')",
            "customer_name": "String",
            "customer_email": "String",
        },
        credential=credential,
    )

    # subscriptions -> customers, reachable as `subscription_customer`.
    DataWarehouseJoin.objects.create(
        team=self.team,
        source_table_name=subscription_table_name,
        source_table_key="subscription_customer_id",
        joining_table_name=customer_table_name,
        joining_table_key="customer_id",
        field_name="subscription_customer",
    )

    # subscriptions -> events, keyed on the joined customer's email.
    DataWarehouseJoin.objects.create(
        team=self.team,
        source_table_name=subscription_table_name,
        source_table_key="subscription_customer.customer_email",
        joining_table_name="events",
        joining_table_key="person.properties.email",
        field_name="events",
        configuration={"experiments_optimized": True, "experiments_timestamp_key": "subscription_created_at"},
    )

    return subscription_table_name

@freeze_time("2020-01-01T12:00:00Z")
def test_query_runner(self):
feature_flag = self.create_feature_flag()
Expand Down Expand Up @@ -2203,6 +2359,101 @@ def test_query_runner_with_data_warehouse_series_expected_query(self):
self.assertEqual(control_result.absolute_exposure, 7)
self.assertEqual(test_result.absolute_exposure, 9)

def test_query_runner_with_data_warehouse_subscriptions_table(self):
    """Run an experiment trends query whose count series is the subscriptions table.

    The subscriptions table is created by
    ``create_data_warehouse_table_with_subscriptions``. Exposure comes from
    ``$feature_flag_called`` events (7 control users, 9 test users), and
    persons carrying an ``email`` property are created for four of those
    users. The test asserts the per-variant counts and absolute exposures
    returned by ``ExperimentTrendsQueryRunner.calculate()``.
    """
    table_name = self.create_data_warehouse_table_with_subscriptions()

    feature_flag = self.create_feature_flag()
    experiment = self.create_experiment(
        feature_flag=feature_flag,
        start_date=datetime(2023, 1, 1),
        end_date=datetime(2023, 1, 10),
    )

    feature_flag_property = f"$feature/{feature_flag.key}"

    count_query = TrendsQuery(
        series=[
            DataWarehouseNode(
                id=table_name,
                distinct_id_field="subscription_customer_id",
                # The subscriptions table has no `id` column; its identifier
                # column is `subscription_id`.
                id_field="subscription_id",
                table_name=table_name,
                timestamp_field="subscription_created_at",
                math="total",
            )
        ]
    )

    experiment_query = ExperimentTrendsQuery(
        experiment_id=experiment.id,
        kind="ExperimentTrendsQuery",
        count_query=count_query,
        exposure_query=None,
    )

    experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}]
    experiment.save()

    # Populate exposure events: one per user, on consecutive days from 2023-01-01.
    for variant, count in [("control", 7), ("test", 9)]:
        for i in range(count):
            _create_event(
                team=self.team,
                event="$feature_flag_called",
                distinct_id=f"user_{variant}_{i}",
                properties={
                    "$feature_flag_response": variant,
                    feature_flag_property: variant,
                    "$feature_flag": feature_flag.key,
                },
                timestamp=datetime(2023, 1, i + 1),
            )

    # Persons whose emails match rows in the customers table. `user_extra`
    # deliberately gets neither a person nor an exposure event.
    person_emails = [
        ("user_control_0", "john.doe@example.com"),
        ("user_test_1", "jane.doe@example.com"),
        ("user_test_2", "john.smith@example.com"),
        ("user_test_3", "jane.smith@example.com"),
    ]
    for distinct_id, email in person_emails:
        _create_person(
            team=self.team,
            distinct_ids=[distinct_id],
            properties={"email": email},
        )

    flush_persons_and_events()

    query_runner = ExperimentTrendsQueryRunner(
        query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team
    )

    with freeze_time("2023-01-10"):
        result = query_runner.calculate()

    trend_result = cast(ExperimentTrendsQueryResponse, result)

    # Use the cast response consistently for every assertion.
    self.assertEqual(len(trend_result.variants), 2)

    control_result = next(variant for variant in trend_result.variants if variant.key == "control")
    test_result = next(variant for variant in trend_result.variants if variant.key == "test")

    # One subscription belongs to a control user and three to test users.
    self.assertEqual(control_result.count, 1)
    self.assertEqual(test_result.count, 3)
    self.assertEqual(control_result.absolute_exposure, 7)
    self.assertEqual(test_result.absolute_exposure, 9)

def test_query_runner_with_invalid_data_warehouse_table_name(self):
# parquet file isn't created, so we'll get an error
table_name = "invalid_table_name"
Expand Down

0 comments on commit 2d10c40

Please sign in to comment.