# schema.py
# stdlib
from enum import Enum
from typing import Any, Dict, List, Optional, Union
# third party
import streamlit as st
from pydantic import BaseModel, Field, model_validator
# first party
from queries import GRAPHQL_QUERIES
# For every Query field that can be sent over GraphQL: "kwarg" is the
# `$variable` reference substituted into the query text and "argument" is the
# GraphQL type declared for that variable (both are consumed by `Query.gql`).
# Insertion order is preserved and determines the order of rendered arguments.
GQL_MAP: Dict = {
    field: {"kwarg": f"${field}", "argument": gql_type}
    for field, gql_type in (
        ("metrics", "[MetricInput!]!"),
        ("groupBy", "[GroupByInput!]"),
        ("where", "[WhereInput!]"),
        ("orderBy", "[OrderByInput!]"),
        ("limit", "Int"),
    )
}
class TimeGranularity(str, Enum):
    # Time grains accepted for a group-by dimension. The `str` mixin makes each
    # member behave as its upper-case string value, so string methods such as
    # `.lower()` (used elsewhere in this file to build `name__grain`) work on
    # members directly.
    #
    # NOTE(review): documented with comments rather than a docstring because a
    # docstring here may surface as a description in generated JSON schema --
    # confirm before adding one.
    day = "DAY"
    week = "WEEK"
    month = "MONTH"
    quarter = "QUARTER"
    year = "YEAR"
class MetricInput(BaseModel):
    # Reference to a single metric by name.
    #
    # Comments are used instead of a class docstring on purpose: pydantic
    # publishes a model's docstring as the JSON-schema description, so adding
    # one would change the schema. The field description below is part of that
    # schema and is reproduced verbatim.
    name: str = Field(
        description=(
            "Metric name defined by the user. "
            "A metric can generally be thought of as a descriptive statistic, "
            "indicator, or figure of merit used to describe or measure "
            "something quantitatively."
        ),
    )
class GroupByInput(BaseModel):
    # Reference to a dimension to group by, optionally qualified with a time
    # grain.
    #
    # Comments are used instead of a class docstring on purpose: pydantic
    # publishes a model's docstring as the JSON-schema description, so adding
    # one would change the schema.

    # pydantic v2 configuration. Replaces the deprecated nested `class Config`,
    # which triggers a deprecation warning when the class is created.
    # `use_enum_values` stores the validated grain as its plain string value
    # (e.g. "DAY") rather than as a TimeGranularity member.
    model_config = {"use_enum_values": True}

    name: str = Field(
        description=(
            "Dimension name defined by the user. They often contain qualitative "
            "values (such as names, dates, or geographical data). You can use "
            "dimensions to categorize, segment, and reveal the details in your data. "
            "A common dimension used here will be metric_time. This will ALWAYS have "
            "an associated grain."
        )
    )
    # None means "no grain": the dimension is then referenced by bare name.
    grain: Optional[TimeGranularity] = Field(
        default=None,
        description=(
            "The grain is the time interval represented by a single point in the data"
        ),
    )
class OrderByInput(BaseModel):
    """
    Important note: Only one of metric or groupBy is allowed to be specified
    """

    # The docstring above is kept verbatim: pydantic publishes it as this
    # model's JSON-schema description.

    # Exactly one of `metric` / `groupBy` must be set (enforced below).
    metric: Optional[MetricInput] = None
    groupBy: Optional[GroupByInput] = None
    # True sorts descending; None/False leaves the default direction.
    descending: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_metric_or_groupBy(cls, values):
        """Reject input that sets neither or both of ``metric`` and ``groupBy``.

        Runs before field validation, so ``values`` is the raw input.

        Raises:
            ValueError: if neither key or both keys carry a non-None value.
        """
        # A "before" validator may receive something other than a dict (for
        # example an existing model instance). Pass it through untouched so
        # pydantic reports a normal validation error instead of this method
        # failing with AttributeError on `.get`.
        if not isinstance(values, dict):
            return values

        has_metric = values.get("metric") is not None
        has_group_by = values.get("groupBy") is not None
        if not (has_metric or has_group_by):
            raise ValueError("either metric or groupBy is required")
        if has_metric and has_group_by:
            raise ValueError("only one of metric or groupBy is allowed")
        return values

    class Config:
        # NOTE(review): `exclude_none` does not appear to be a pydantic config
        # option, so this likely has no effect -- serialization sites in this
        # file pass `exclude_none=True` to `model_dump` explicitly. Left
        # unchanged to avoid altering behaviour; confirm and remove.
        exclude_none = True
class WhereInput(BaseModel):
    # One filter expression. Elsewhere in this file, `Query._jdbc_text` joins
    # several of these with " AND ", and `QueryLoader._where` produces values
    # shaped like "{{ Dimension('name') }} <operator> '<value>'".
    #
    # Comments are used instead of a class docstring on purpose: pydantic
    # publishes a model's docstring as the JSON-schema description.
    sql: str
class Query(BaseModel):
    # A metrics query plus renderers for the representations used in this
    # file: JDBC text (`jdbc_query`), GraphQL (`gql` with `variables`) and the
    # keyword-style dict returned by `sdk`.
    #
    # Comments are used instead of a class docstring on purpose: pydantic
    # publishes a model's docstring as the JSON-schema description, so adding
    # one would change the schema.
    metrics: List[MetricInput]
    groupBy: Optional[List[GroupByInput]] = None
    where: Optional[List[WhereInput]] = None
    orderBy: Optional[List[OrderByInput]] = None
    limit: Optional[int] = None

    @staticmethod
    def _qualified_name(item: Union[MetricInput, GroupByInput]) -> str:
        """Return ``name``, or ``name__grain`` when *item* carries a grain.

        The grain is lower-cased (``metric_time__day``). Items without a
        ``grain`` attribute (metrics) are returned by bare name.
        """
        grain = getattr(item, "grain", None)
        if grain is None:
            return item.name
        return f"{item.name}__{grain.lower()}"

    @property
    def all_names(self) -> List[str]:
        """Metric names followed by (grain-qualified) dimension names."""
        return self.metric_names + self.dimension_names

    @property
    def metric_names(self) -> List[str]:
        """Names of the requested metrics, in request order."""
        return [m.name for m in self.metrics]

    @property
    def dimension_names(self) -> List[str]:
        """Group-by names, grain-qualified where a grain is set."""
        return [self._qualified_name(g) for g in self.groupBy or []]

    @property
    def time_dimension_names(self) -> List[str]:
        """Grain-qualified names of only those group-bys that have a grain."""
        return [
            self._qualified_name(g)
            for g in self.groupBy or []
            if g.grain is not None
        ]

    @property
    def has_time_dimension(self) -> bool:
        """True when at least one group-by carries a grain."""
        return any(g.grain is not None for g in self.groupBy or [])

    @property
    def has_multiple_metrics(self) -> bool:
        """True when more than one metric is requested."""
        return len(self.metrics) > 1

    @property
    def used_inputs(self) -> List[str]:
        """Keys of ``GQL_MAP`` whose field on this query is set and non-empty.

        ``None`` and empty lists count as unused. Values without a length
        (such as an integer ``limit``) count as used whenever they are not
        ``None``.
        """
        inputs = []
        for key in GQL_MAP:
            value = getattr(self, key)
            if value is None:
                continue
            try:
                is_empty = len(value) == 0
            except TypeError:
                # Scalars have no len(); being non-None is enough.
                is_empty = False
            if not is_empty:
                inputs.append(key)
        return inputs

    @property
    def _jdbc_text(self) -> str:
        """Argument text placed inside ``semantic_layer.query(...)``.

        Empty ``groupBy`` / ``where`` / ``orderBy`` lists are skipped, matching
        how ``used_inputs`` and ``sdk`` treat them; this avoids emitting an
        empty ``where=""`` filter.
        """
        text = f"metrics={self.metric_names}"
        if self.groupBy:
            text += f",\n group_by={self.dimension_names}"
        if self.where:
            where = " AND ".join([w.sql for w in self.where])
            text += f',\n where="{where}"'
        if self.orderBy:
            names = []
            for order in self.orderBy:
                target = order.metric if order.metric else order.groupBy
                name = self._qualified_name(target)
                # A leading "-" marks descending order.
                names.append(f"-{name}" if order.descending else name)
            text += f",\n order_by={names}"
        if self.limit is not None:
            text += f",\n limit={self.limit}"
        return text

    @property
    def jdbc_query(self):
        """SQL text wrapping ``_jdbc_text`` in a ``semantic_layer.query`` call."""
        # The quadruple braces render as literal "{{" and "}}" in the output.
        sql = f"""
select *
from {{{{
semantic_layer.query(
{self._jdbc_text}
)
}}}}
"""
        return sql

    @property
    def gql(self) -> str:
        """The ``create_query`` GraphQL template filled in for the used inputs.

        ``environmentId`` is always declared; every entry of ``used_inputs``
        adds its variable declaration and its keyword from ``GQL_MAP``.
        """
        query = GRAPHQL_QUERIES["create_query"]
        kwargs = {"environmentId": "$environmentId"}
        arguments = {"environmentId": "BigInt!"}
        for name in self.used_inputs:
            kwargs[name] = GQL_MAP[name]["kwarg"]
            arguments[name] = GQL_MAP[name]["argument"]
        return query.format(
            arguments=", ".join(f"${k}: {v}" for k, v in arguments.items()),
            kwargs=",\n ".join(f"{k}: {v}" for k, v in kwargs.items()),
        )

    @property
    def sdk(self) -> Dict[str, Any]:
        """The query as a dict with keys ``metrics``, ``group_by``, ``where``,
        ``order_by`` and ``limit``.

        Metrics and group-bys are rendered as strings where possible; the
        grain suffix is lower-cased so the names match ``dimension_names`` and
        the JDBC rendering. Unset list inputs become ``[]`` and a falsy limit
        becomes ``None``.
        """

        def str_or_dict(item):
            item_dict = item.model_dump(exclude_none=True)
            if len(item_dict) == 1:
                return item_dict["name"]
            if len(item_dict) == 2 and "grain" in item_dict:
                return f'{item_dict["name"]}__{item_dict["grain"].lower()}'
            # Anything richer than name(+grain) is passed through as a dict.
            return item_dict

        return {
            "metrics": [str_or_dict(m) for m in self.metrics or []],
            "group_by": [str_or_dict(g) for g in self.groupBy or []],
            "where": [w.sql for w in self.where or []],
            "order_by": [
                ("-" if o.descending else "")
                + str_or_dict(o.metric if o.metric else o.groupBy)
                for o in self.orderBy or []
            ],
            "limit": self.limit or None,
        }

    @property
    def variables(self) -> Dict[str, Any]:
        """GraphQL variable values for every entry of ``used_inputs``.

        Lists of models are dumped element-wise with ``None`` fields dropped;
        plain values (such as ``limit``) are passed through unchanged.
        """
        variables = {}
        for name in self.used_inputs:
            data = getattr(self, name)
            if isinstance(data, list):
                variables[name] = [m.model_dump(exclude_none=True) for m in data]
            else:
                try:
                    variables[name] = data.model_dump(exclude_none=True)
                except AttributeError:
                    # Not a model (e.g. an int limit): use the value as is.
                    variables[name] = data
        return variables
class QueryLoader:
    """Assemble a ``Query`` from values held in a Streamlit session state.

    The state is read for ``selected_metrics``, ``selected_dimensions``,
    ``selected_grain``, ``selected_limit`` and ``dimension_dict``, plus the
    indexed entries ``where_column_<i>`` / ``where_operator_<i>`` /
    ``where_condition_<i>`` and ``order_column_<i>`` / ``order_direction_<i>``.
    """

    # Indexed where_*/order_* entries are scanned for i in range(_MAX_CLAUSES);
    # scanning stops at the first index whose column is missing or None.
    _MAX_CLAUSES = 10

    def __init__(self, state: st.session_state):
        self.state = state

    def create(self) -> "Query":
        """Build the ``Query``; empty optional inputs are passed as ``None``."""
        return Query(
            metrics=self._metrics,
            groupBy=self._groupBy or None,
            where=self._where or None,
            orderBy=self._orderBy or None,
            limit=self._limit or None,
        )

    def _is_time_dimension(self, dimension: str) -> bool:
        """True when ``dimension_dict[dimension]["type"]`` is "time" (any case)."""
        return self.state.dimension_dict[dimension]["type"].lower() == "time"

    @property
    def _grain(self) -> str:
        """Upper-cased ``selected_grain`` from the state, or "DAY" when unset.

        Shared by the group-by, where and order-by builders so that all three
        resolve a missing grain the same way.
        """
        return (self.state.get("selected_grain") or "day").upper()

    @property
    def _metrics(self):
        """One ``MetricInput`` per entry of ``selected_metrics``."""
        return [MetricInput(name=m) for m in self.state.selected_metrics]

    @property
    def _groupBy(self):
        """One ``GroupByInput`` per selected dimension.

        Time dimensions additionally receive the selected grain.
        """
        dimensions = []
        for dimension in self.state.selected_dimensions:
            kwargs = {"name": dimension}
            if self._is_time_dimension(dimension):
                kwargs["grain"] = self._grain
            dimensions.append(GroupByInput(**kwargs))
        return dimensions

    @property
    def _where(self):
        """``WhereInput`` filters built from the indexed ``where_*`` entries."""

        def where_dimension(dimension: str) -> str:
            if self._is_time_dimension(dimension):
                return f"TimeDimension('{dimension}', '{self._grain}')"
            return f"Dimension('{dimension}')"

        def where_condition(condition: Union[List, tuple, str]) -> str:
            # list  -> ('a', 'b', ...)
            # tuple -> 'low' AND 'high'
            # other -> 'value'
            # NOTE(review): values are wrapped in single quotes without any
            # escaping, so a value containing "'" yields a malformed filter.
            # The right escape depends on the target SQL dialect -- confirm.
            if isinstance(condition, list):
                return "(" + ", ".join(f"'{item}'" for item in condition) + ")"
            if isinstance(condition, tuple):
                return f"'{condition[0]}' AND '{condition[1]}'"
            return f"'{condition}'"

        wheres = []
        for i in range(self._MAX_CLAUSES):
            column = self.state.get(f"where_column_{i}")
            if column is None:
                break
            dimension_arg = where_dimension(column)
            condition_arg = where_condition(self.state[f"where_condition_{i}"])
            operator = self.state[f"where_operator_{i}"]
            wheres.append(
                WhereInput(
                    sql=f"{{{{ {dimension_arg} }}}} {operator} {condition_arg}"
                )
            )
        return wheres

    @property
    def _orderBy(self):
        """``OrderByInput`` entries built from the indexed ``order_*`` entries.

        A column found in ``selected_metrics`` orders by that metric; any other
        column is treated as a dimension and gets the selected grain when it is
        a time dimension. ``descending`` is set only for direction "desc".
        """
        orderBys = []
        for i in range(self._MAX_CLAUSES):
            name = self.state.get(f"order_column_{i}")
            if name is None:
                break
            if name in self.state.selected_metrics:
                spec = {"metric": {"name": name}}
            else:
                group_by = {"name": name}
                if self._is_time_dimension(name):
                    group_by["grain"] = self._grain
                spec = {"groupBy": group_by}
            if self.state[f"order_direction_{i}"].lower() == "desc":
                spec["descending"] = True
            orderBys.append(OrderByInput(**spec))
        return orderBys

    @property
    def _limit(self):
        """``selected_limit`` from the state, or ``None`` when it is None or 0."""
        limit = self.state.selected_limit
        if limit is not None and limit != 0:
            return limit
        return None