Skip to content

Commit

Permalink
Properly persist full left and right chain for compare eq
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbachhuber committed Feb 4, 2025
1 parent aafcc0f commit 8e2dd3c
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions posthog/warehouse/models/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,8 @@ def _join_function(
if not join_to_add.fields_accessed:
raise ResolutionError(f"No fields requested from {join_to_add.to_table}")

left = parse_expr(_source_table_key)
if isinstance(left, ast.Field):
left.chain = [join_to_add.from_table, *left.chain]
elif isinstance(left, ast.Call) and isinstance(left.args[0], ast.Field):
left.args[0].chain = [join_to_add.from_table, *left.args[0].chain]
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node")

right = parse_expr(_joining_table_key)
if isinstance(right, ast.Field):
right.chain = [join_to_add.to_table, *right.chain]
elif isinstance(right, ast.Call) and isinstance(right.args[0], ast.Field):
right.args[0].chain = [join_to_add.to_table, *right.args[0].chain]
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node")
left = self.__parse_table_key_expression(_source_table_key, join_to_add.from_table)
right = self.__parse_table_key_expression(_joining_table_key, join_to_add.to_table)

join_expr = ast.JoinExpr(
table=ast.SelectQuery(
Expand Down Expand Up @@ -119,6 +106,9 @@ def _join_function_for_experiments(
if not timestamp_key:
raise ResolutionError("experiments_timestamp_key is not set for this join")

left = self.__parse_table_key_expression(self.source_table_key, join_to_add.from_table)
right = self.__parse_table_key_expression(self.joining_table_key, join_to_add.to_table)

whereExpr: list[ast.Expr] = [
ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
Expand Down Expand Up @@ -184,14 +174,9 @@ def _join_function_for_experiments(
right=ast.Constant(value="$feature_flag_called"),
),
ast.CompareOperation(
left=ast.Field(
chain=[
join_to_add.from_table,
self.source_table_key,
]
),
left=left,
op=ast.CompareOperationOp.Eq,
right=ast.Field(chain=[join_to_add.to_table, *self.joining_table_key.split(".")]),
right=right,
),
ast.CompareOperation(
left=ast.Field(
Expand All @@ -210,3 +195,14 @@ def _join_function_for_experiments(
)

return _join_function_for_experiments

def __parse_table_key_expression(self, table_key: str, table_name: str) -> ast.Expr:
expr = parse_expr(table_key)
if isinstance(expr, ast.Field):
expr.chain = [table_name, *expr.chain]
elif isinstance(expr, ast.Call) and isinstance(expr.args[0], ast.Field):
expr.args[0].chain = [table_name, *expr.args[0].chain]
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node")

return expr

0 comments on commit 8e2dd3c

Please sign in to comment.