-
Notifications
You must be signed in to change notification settings - Fork 591
/
__init__.py
97 lines (62 loc) · 2.26 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Callable
import sqlglot as sg
if TYPE_CHECKING:
import ibis.expr.datatypes as dt
from ibis.backends.base.sqlglot.datatypes import SqlglotType
class AggGen:
    """Build aggregate-call expressions through attribute access.

    ``agg.name(*args, **kwargs)`` calls ``aggfunc("name", *args, **kwargs)``.
    ``agg["name"](...)`` does the same and also works for names that are
    not valid Python identifiers.

    Parameters
    ----------
    aggfunc
        Callable that receives the aggregate name as its first positional
        argument, followed by whatever the caller passes.
    """

    __slots__ = ("aggfunc",)

    def __init__(self, *, aggfunc: Callable) -> None:
        self.aggfunc = aggfunc

    def __getattr__(self, name: str) -> partial:
        # Only reached after normal attribute lookup has failed.  Refuse:
        # * the ``aggfunc`` slot itself, so an instance whose slot was never
        #   set (e.g. one being rebuilt by copy/pickle) raises AttributeError
        #   instead of recursing through ``self.aggfunc`` forever;
        # * dunder names, which are Python protocol probes (``__deepcopy__``,
        #   ``__getnewargs_ex__``, ``__setstate__``, ``__array__``, ...) and
        #   never aggregate names -- answering them with a call builder breaks
        #   copy/pickle and anything else that probes with hasattr/getattr.
        if name == "aggfunc" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return partial(self.aggfunc, name)

    def __getitem__(self, key: str) -> partial:
        # Build the call directly rather than via getattr, so any key works,
        # including ones the attribute path deliberately refuses.
        return partial(self.aggfunc, key)
def _func(name: str, *args: Any, **kwargs: Any):
    """Build a sqlglot call to the SQL function *name*.

    Each positional argument is first passed through ``sg.exp.convert``.
    Keyword arguments are forwarded to ``sg.func`` unchanged.
    """
    converted_args = [sg.exp.convert(arg) for arg in args]
    return sg.func(name, *converted_args, **kwargs)
class FuncGen:
    """Build sqlglot function-call expressions through attribute access.

    ``F.name(*args, **kwargs)`` returns ``_func("name", *args, **kwargs)``.
    ``F["name"]`` is equivalent to ``getattr(F, "name")``. A few names
    (``array``, ``tuple``, ``exists``, ``concat``, ``map``) are defined
    explicitly below and build dedicated sqlglot nodes instead.
    """

    __slots__ = ()

    def __getattr__(self, name: str) -> partial:
        # Only reached after normal attribute lookup has failed.  Dunder names
        # are Python protocol probes (copy's ``__deepcopy__``, pickle's
        # ``__getnewargs_ex__``, numpy's ``__array__``, ...), never SQL
        # functions: refuse them so those probes fall back to their defaults
        # instead of receiving (and calling) a function-call builder.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return partial(_func, name)

    def __getitem__(self, key: str) -> partial:
        # Same as attribute access, so F["concat"] still reaches the
        # dedicated method below.
        return getattr(self, key)

    def array(self, *args):
        """Build an ``Array`` node from *args*, each converted with ``sg.exp.convert``."""
        return sg.exp.Array.from_arg_list(list(map(sg.exp.convert, args)))

    def tuple(self, *args):
        """Build a ``tuple(...)`` function call from the converted *args*."""
        return sg.func("tuple", *map(sg.exp.convert, args))

    def exists(self, query):
        """Wrap *query* in an ``Exists`` node."""
        return sg.exp.Exists(this=query)

    def concat(self, *args):
        """Build a ``Concat`` node from *args*, each converted with ``sg.exp.convert``."""
        return sg.exp.Concat.from_arg_list(list(map(sg.exp.convert, args)))

    def map(self, keys, values):
        """Build a ``Map`` node from *keys* and *values*.

        Unlike the other helpers, the arguments are NOT passed through
        ``sg.exp.convert``; presumably callers pass sqlglot expressions
        already (TODO confirm).
        """
        return sg.exp.Map(keys=keys, values=values)
class ColGen:
    """Build sqlglot column references: ``C.name`` / ``C["name"]`` -> ``sg.column(name)``."""

    __slots__ = ()

    def __getattr__(self, name: str) -> sg.exp.Column:
        # Only reached after normal attribute lookup has failed.  Dunder names
        # are Python protocol probes (``copy.deepcopy`` looks up
        # ``__deepcopy__`` and calls it; numpy probes ``__array__``, ...), not
        # column names.  Returning a Column for them breaks those protocols.
        # A column whose name really is a dunder is still reachable via
        # ``C["__name__"]``.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return sg.column(name)

    def __getitem__(self, key: str) -> sg.exp.Column:
        # Bracket access accepts any name, including ones that are not
        # valid identifiers.
        return sg.column(key)
def paren(expr):
    """Return *expr* enclosed in a sqlglot ``Paren`` node."""
    wrapped = sg.exp.Paren(this=expr)
    return wrapped
def parenthesize(op, arg):
    """Wrap *arg* in parentheses when *op* is a binary or unary ibis operation.

    Any other operation (e.g. a function call) is returned unchanged,
    because it does not need parentheses.
    """
    # Imported lazily, as in the original; presumably to avoid an import
    # cycle (TODO confirm).
    import ibis.expr.operations as ops

    is_operator = isinstance(op, (ops.Binary, ops.Unary))
    return paren(arg) if is_operator else arg
def interval(value, *, unit):
    """Build a sqlglot ``Interval`` node.

    *value* is converted with ``sg.exp.convert``; *unit* is wrapped with
    ``sg.exp.var``.
    """
    amount = sg.exp.convert(value)
    unit_node = sg.exp.var(unit)
    return sg.exp.Interval(this=amount, unit=unit_node)
# Shared module-level generator instances: C builds column references,
# F builds function calls.
C, F = ColGen(), FuncGen()

# Re-exported sqlglot literal nodes, plus a single shared ``*`` node.
NULL, FALSE, TRUE = sg.exp.NULL, sg.exp.FALSE, sg.exp.TRUE
STAR = sg.exp.Star()
def make_cast(
    converter: SqlglotType,
) -> Callable[[sg.exp.Expression, dt.DataType], sg.exp.Cast]:
    """Return a cast builder bound to *converter*.

    The returned ``cast(arg, to)`` translates the ibis type *to* with
    ``converter.from_ibis`` and returns ``sg.cast(arg, to=<translated type>)``.
    """

    def cast(arg: sg.exp.Expression, to: dt.DataType) -> sg.exp.Cast:
        target_type = converter.from_ibis(to)
        return sg.cast(arg, to=target_type)

    return cast