"""SQL client handling.
This includes mssqlStream and mssqlConnector.
"""
from __future__ import annotations
import gzip
import json
from datetime import datetime
from uuid import uuid4
from typing import Any, Dict, Iterable, Optional
import pendulum
import pyodbc
import sqlalchemy
from sqlalchemy.engine import URL
from singer_sdk import SQLConnector, SQLStream
from singer_sdk.helpers._batch import (
BaseBatchFileEncoding,
BatchConfig,
)
from singer_sdk.streams.core import lazy_chunked_generator
class mssqlConnector(SQLConnector):
"""Connects to the mssql SQL source."""
def __init__(self, config: dict | None = None, sqlalchemy_url: str | None = None) -> None:
"""Class Default Init"""
# If pyodbc given set pyodbc.pooling to False
# This allows SQLA to manage to connection pool
if config['driver_type'] == 'pyodbc':
pyodbc.pooling = False
super().__init__(config, sqlalchemy_url)

    def get_sqlalchemy_url(self, config: dict) -> str:
        """Concatenate a SQLAlchemy URL for use in connecting to the source."""
        if config['dialect'] == "mssql":
            url_drivername: str = config['dialect']
        else:
            self.logger.error("Invalid dialect given")
            exit(1)

        if config['driver_type'] in ["pyodbc", "pymssql"]:
            url_drivername += f"+{config['driver_type']}"
        else:
            self.logger.error("Invalid driver_type given")
            exit(1)

        config_url = URL.create(
            url_drivername,
            username=config['user'],
            password=config['password'],
            host=config['host'],
            database=config['database'],
        )

        if 'port' in config:
            config_url = config_url.set(port=config['port'])

        if 'sqlalchemy_url_query' in config:
            config_url = config_url.update_query_dict(config['sqlalchemy_url_query'])

        return config_url
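
    # A minimal sketch of what the method above produces, assuming the
    # hypothetical config values shown here (not taken from any real setup):
    #
    #   {"dialect": "mssql", "driver_type": "pyodbc", "user": "sa",
    #    "password": "secret", "host": "localhost", "port": 1433,
    #    "database": "master"}
    #
    # yields a URL equivalent to:
    #
    #   mssql+pyodbc://sa:secret@localhost:1433/master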

    def create_sqlalchemy_engine(self) -> sqlalchemy.engine.Engine:
        """Return a new SQLAlchemy engine using the provided config.

        Developers can generally override just one of the following:
        `sqlalchemy_engine`, `sqlalchemy_url`.

        Returns:
            A newly created SQLAlchemy engine object.
        """
        eng_prefix = "ep."
        eng_config = {
            f"{eng_prefix}url": self.sqlalchemy_url,
            f"{eng_prefix}echo": "False",
        }

        if self.config.get('sqlalchemy_eng_params'):
            for key, value in self.config['sqlalchemy_eng_params'].items():
                eng_config.update({f"{eng_prefix}{key}": value})

        return sqlalchemy.engine_from_config(eng_config, prefix=eng_prefix)
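
    # Illustration of the prefixing above, assuming a hypothetical
    # `sqlalchemy_eng_params` block in the tap config:
    #
    #   {"sqlalchemy_eng_params": {"pool_size": "5", "pool_timeout": "30"}}
    #
    # results in engine_from_config receiving, alongside "ep.url" and "ep.echo":
    #
    #   {"ep.pool_size": "5", "ep.pool_timeout": "30"}
    #
    # engine_from_config strips the "ep." prefix and passes the remaining keys
    # as keyword arguments to sqlalchemy.create_engine().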

    @staticmethod
    def to_jsonschema_type(sql_type: sqlalchemy.types.TypeEngine) -> dict:
        """Return a JSON Schema equivalent for the given SQL type.

        Developers may optionally add custom logic before calling the default
        implementation inherited from the base class.

        Checks for the MSSQL type NUMERIC:
        if scale = 0 it is typed as an INTEGER,
        if scale != 0 it is typed as a NUMBER.
        """
        if str(sql_type).startswith("NUMERIC"):
            if str(sql_type).endswith(", 0)"):
                sql_type = "int"
            else:
                sql_type = "number"

        if str(sql_type) in ["MONEY", "SMALLMONEY"]:
            sql_type = "number"

        return SQLConnector.to_jsonschema_type(sql_type)
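
    # Rough examples of the remapping above (the final JSON Schema comes from
    # SQLConnector.to_jsonschema_type, so exact output may vary by SDK version):
    #
    #   NUMERIC(10, 0) -> "int"    -> an integer-typed schema
    #   NUMERIC(10, 2) -> "number" -> a number-typed schema
    #   MONEY          -> "number" -> a number-typed schema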

    @staticmethod
    def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
        """Return a SQL type equivalent for the given JSON Schema type.

        Developers may optionally add custom logic before calling the default
        implementation inherited from the base class.
        """
        # Optionally, add custom logic before calling the parent SQLConnector method.
        # You may delete this method if overrides are not needed.
        return SQLConnector.to_sql_type(jsonschema_type)


# Custom JSON encoder that extends json.JSONEncoder.
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes datetime values as ISO 8601 strings."""

    def default(self, obj):
        # Convert datetimes to ISO 8601 strings, e.g. "2020-08-21T17:59:59+00:00".
        if isinstance(obj, datetime):
            return pendulum.instance(obj).isoformat()

        # Default behavior for all other types.
        return super().default(obj)
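
# Usage sketch for the encoder above, assuming a naive datetime (pendulum
# treats naive datetimes as UTC by default):
#
#   json.dumps({"updated_at": datetime(2020, 8, 21, 17, 59, 59)},
#              cls=CustomJSONEncoder)
#
# would produce something like:
#
#   '{"updated_at": "2020-08-21T17:59:59+00:00"}'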


class mssqlStream(SQLStream):
    """Stream class for mssql streams."""

    connector_class = mssqlConnector

    def get_records(self, partition: Optional[dict]) -> Iterable[Dict[str, Any]]:
        """Return a generator of record-type dictionary objects.

        Developers may optionally add custom logic before calling the default
        implementation inherited from the base class.

        Args:
            partition: If provided, will read specifically from this data slice.

        Yields:
            One dict per record.
        """
        # Optionally, add custom logic instead of calling the super().
        # This is helpful if the source database provides batch-optimized record
        # retrieval.
        # If no overrides or optimizations are needed, you may delete this method.
        yield from super().get_records(partition)

    def get_batches(
        self,
        batch_config: BatchConfig,
        context: dict | None = None,
    ) -> Iterable[tuple[BaseBatchFileEncoding, list[str]]]:
        """Batch generator function.

        Developers are encouraged to override this method to customize batching
        behavior for databases, bulk APIs, etc.

        Args:
            batch_config: Batch config for this stream.
            context: Stream partition or context dictionary.

        Yields:
            A tuple of (encoding, manifest) for each batch.
        """
        sync_id = f"{self.tap_name}--{self.name}-{uuid4()}"
        prefix = batch_config.storage.prefix or ""

        for i, chunk in enumerate(
            lazy_chunked_generator(
                self._sync_records(context, write_messages=False),
                self.batch_size,
            ),
            start=1,
        ):
            filename = f"{prefix}{sync_id}-{i}.json.gz"
            with batch_config.storage.fs() as fs:
                with fs.open(filename, "wb") as f:
                    # TODO: Determine compression from config.
                    with gzip.GzipFile(fileobj=f, mode="wb") as gz:
                        gz.writelines(
                            (json.dumps(record, cls=CustomJSONEncoder) + "\n").encode()
                            for record in chunk
                        )
                file_url = fs.geturl(filename)

            yield batch_config.encoding, [file_url]
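
    # Shape of one yielded batch, with illustrative (not real) values; the
    # exact file URL depends on the configured storage backend:
    #
    #   filename: "<storage prefix><tap_name>--<stream_name>-<uuid4>-1.json.gz"
    #   yield:    (batch_config.encoding, ["file:///tmp/...-1.json.gz"])
    #
    # Each file is gzip-compressed JSONL, one record per line, encoded with
    # CustomJSONEncoder so datetime columns survive serialization.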