-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpeewee_extras.py
371 lines (284 loc) · 10.9 KB
/
peewee_extras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import peewee
import datetime
import playhouse
from peewee import DateTimeField
####################################################################
# Model manager
####################################################################
class ModelManager(list):
    """Registry of model classes managed by a DatabaseManager.

    The registered model classes are stored as the items of this list.
    """

    def __init__(self, database_manager):
        super().__init__()
        # Assigned as `_meta.database` of every model passed to register().
        self.dbm = database_manager

    def create_tables(self):
        """Create the table of every registered model.

        Tables are created in foreign-key dependency order (computed by
        `peewee.sort_models`), so a referenced table exists before any
        table pointing at it. `safe=True` skips tables that already exist;
        it replaces the deprecated `fail_silently` keyword.
        """
        for cls in peewee.sort_models(self):
            cls.create_table(safe=True)

    def destroy_tables(self):
        """Drop the table of every registered model.

        Tables are dropped in reverse dependency order, so referencing
        tables are removed before the tables they point at. `safe=True`
        ignores tables that do not exist.
        """
        for cls in reversed(peewee.sort_models(self)):
            cls.drop_table(safe=True)

    def register(self, model_cls):
        """Register a model class and bind it to the database manager.

        Returns the class, so this can be used as a class decorator.

        :raises RuntimeError: if the model is already registered
        """
        assert issubclass(model_cls, peewee.Model)
        # NOTE(review): nothing in this module sets `_meta.database_manager`;
        # confirm what this guard is meant to catch.
        assert not hasattr(model_cls._meta, 'database_manager')
        if model_cls in self:
            raise RuntimeError("Model already registered")
        self.append(model_cls)
        model_cls._meta.database = self.dbm
        return model_cls
####################################################################
# DB manager
####################################################################
# XXX: improve KeyError message
class DatabaseManager(dict):
    """Registry of named peewee databases plus routing rules.

    Maps a name to a `peewee.Database`. `get_database()` asks each router
    for a database and falls back to the entry named 'default'.
    """

    def __init__(self):
        super().__init__()
        # Routers consulted by get_database(). NOTE: a set has no defined
        # iteration order, so if several routers return a database for the
        # same model, which one wins is unspecified.
        self.routers = set()
        # Models registered against this manager.
        self.models = ModelManager(database_manager=self)

    def connect(self):
        """Open a connection on every registered database.

        Databases that are already open are skipped, mirroring
        disconnect(), which only closes databases that are open.
        """
        for db in self.values():
            if db.is_closed():
                db.connect()

    def disconnect(self):
        """Close every registered database that is currently open."""
        for db in self.values():
            if not db.is_closed():
                db.close()

    def get_database(self, model):
        """Return the first non-None database a router picks for `model`.

        Falls back to the database registered as 'default', or None if
        there is no such entry.
        """
        for router in self.routers:
            db = router.get_database(model)
            if db is not None:
                return db
        return self.get('default')

    def register(self, name, db):
        """Register a database under `name`.

        :param db: a database URL string (opened with
            `playhouse.db_url.connect`) or a `peewee.Database` instance
        :raises ValueError: if `db` is neither
        """
        if isinstance(db, str):
            # `import playhouse` alone does not load the `db_url` submodule,
            # so import it explicitly before use.
            from playhouse.db_url import connect as connect_url
            self[name] = connect_url(db)
        elif isinstance(db, peewee.Database):
            self[name] = db
        else:
            raise ValueError("unexpected 'db' type")
####################################################################
# Database routers
####################################################################
class DatabaseRouter(object):
    """Base class for database routers.

    Subclasses override `get_database` to return the database a model
    should use. Returning None defers the decision to the next router,
    and ultimately to the manager's 'default' database.
    """

    def get_database(self, model):
        # The base router never expresses a preference.
        chosen = None
        return chosen
####################################################################
# Model
####################################################################
class Metadata(peewee.Metadata):
    """Model metadata whose `database` may be a DatabaseManager.

    When the stored database is a DatabaseManager, reading `database` asks
    the manager (and its routers) for a concrete database. If the manager
    returns nothing truthy, the manager itself is returned.
    """

    # Backing storage for the `database` property.
    _database = None

    @property
    def database(self):
        stored = self._database
        if not isinstance(stored, DatabaseManager):
            return stored
        # NOTE(review): the metadata object (not the model class) is what is
        # handed to the manager/routers here, while router signatures name
        # the argument `model` -- confirm which one routers expect.
        routed = stored.get_database(self)
        return routed if routed else stored

    @database.setter
    def database(self, value):
        self._database = value
class Model(peewee.Model):
    """Base model with convenience helpers and manager-aware metadata."""

    class Meta:
        # Use the routing-aware Metadata class defined in this module.
        model_metadata_class = Metadata

    def update_instance(self, **kwargs):
        """Assign each keyword argument as an attribute, then save()."""
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.save()

    @classmethod
    def create_or_get(cls, **kwargs):
        """Create a row from `kwargs`, or fetch the existing matching row.

        :returns: `(instance, created)`, where `created` is True when a new
            row was inserted
        :raises DoesNotExist: if the insert hit an integrity error but no
            row matches all of `kwargs`
        """
        try:
            # Only the INSERT runs inside the atomic block: when it raises,
            # atomic() rolls it back before the fallback query below runs.
            # Issuing get() inside the failed transaction breaks on backends
            # (e.g. PostgreSQL) that reject statements in an aborted
            # transaction.
            with cls.atomic():
                return cls.create(**kwargs), True
        except peewee.IntegrityError:
            return cls.get(**kwargs), False

    @classmethod
    def get_or_none(cls, **kwargs):
        """
        Return the row matching `kwargs`, or None if there is none

        XXX: needs unit test
        """
        try:
            return cls.get(**kwargs)
        except cls.DoesNotExist:
            return None

    @classmethod
    def atomic(cls):
        """Shortcut method for creating atomic context"""
        return cls._meta.database.atomic()

    def to_cursor_ref(self):
        """Returns dict of primary key field name -> raw stored value

        The result uniquely references this row and can be passed to
        `from_cursor_ref()`. Requires the model to have a primary key.
        """
        fields = self._meta.get_primary_keys()
        assert fields
        # Read the raw values from __data__ rather than via attribute access.
        return {field.name: self.__data__[field.name] for field in fields}

    @classmethod
    def from_cursor_ref(cls, cursor):
        """Returns model instance from unique cursor reference"""
        return cls.get(**cursor)

    def refetch(self):
        """
        Return new model instance with fresh data from database

        Only works on models which have a primary or compound key
        See https://github.com/coleifer/peewee/issues/638
        XXX: Add support for models without PK
        """
        ref = self.to_cursor_ref()
        return self.from_cursor_ref(ref)
####################################################################
# Mixins
####################################################################
def utcnow_no_ms():
    """Return the current UTC time as a naive datetime without microseconds.

    The result is naive (tzinfo is None), which is exactly what the
    deprecated `datetime.datetime.utcnow()` produced, so stored values keep
    the same form.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    return now.replace(tzinfo=None, microsecond=0)
class TimestampModelMixin(Model):
    """Track creation and modification times (naive UTC, whole seconds)."""

    created = DateTimeField(default=utcnow_no_ms)
    # Same default as `created`, so the column is populated even on write
    # paths that do not go through save(); save() overwrites it.
    modified = DateTimeField(default=utcnow_no_ms)

    def save(self, **kwargs):
        # Use the same clock as `created` (naive UTC, no microseconds);
        # stamping local time here would mix time zones within one row.
        self.modified = utcnow_no_ms()
        return super(TimestampModelMixin, self).save(**kwargs)
####################################################################
# Pagination
####################################################################
class Pagination:
    """Base class for pagination strategies (e.g. PrimaryKeyPagination)."""
class PrimaryKeyPagination(Pagination):
    """
    Primary key pagination

    It does not support models with compound keys or no primary key
    as doing so would require using LIMIT/OFFSET which has terrible
    performance at scale. If you want this, send a PR.
    """

    @classmethod
    def paginate_query(cls, query, count, offset=None, sort=None):
        """
        Apply pagination to query

        :attr query: Instance of `peewee.Query`
        :attr count: Max rows to return
        :attr offset: Primary key value of the first row to return
            (inclusive), str/int, or None to start from the beginning
        :attr sort: List of tuples, e.g. [('id', 'asc')]; each name must be
            a field name or column name of the query's model
        :returns: Instance of `peewee.Query`
        :raises peewee.ProgrammingError: model has no, or a compound,
            primary key
        :raises ValueError: invalid sort direction or unknown sort field
        """
        assert isinstance(query, peewee.Query)
        assert isinstance(count, int)
        assert isinstance(offset, (str, int, type(None)))
        assert isinstance(sort, (list, set, tuple, type(None)))

        # ensure our model has exactly one primary key field
        fields = query.model._meta.get_primary_keys()
        if len(fields) == 0:
            raise peewee.ProgrammingError(
                'Cannot apply pagination on model without primary key')
        if len(fields) > 1:
            raise peewee.ProgrammingError(
                'Cannot apply pagination on model with compound primary key')
        pk = fields[0]

        # apply offset (inclusive: the offset row starts the page)
        if offset is not None:
            query = query.where(pk >= offset)

        order_bys = []
        for name, direction in sort or ():
            # does this field have a valid sort direction?
            if not isinstance(direction, str):
                raise ValueError("Invalid sort direction on field '{}'".format(name))
            direction = direction.lower().strip()
            if direction not in ('asc', 'desc'):
                raise ValueError("Invalid sort direction on field '{}'".format(name))
            column = cls._get_sort_field(query.model, name)
            order_bys.append(getattr(column, direction)())
        # primary key ordering after user sorting
        order_bys.append(pk.asc())

        return query.order_by(*order_bys).limit(count)

    @staticmethod
    def _get_sort_field(model, name):
        """Resolve a sort name to a field of `model`, or raise ValueError.

        Sort names may come from user input, so they are never placed into
        SQL verbatim: only names matching a field (by attribute name or by
        column name) are accepted.
        """
        meta = model._meta
        field = None
        if isinstance(name, str):
            field = meta.fields.get(name)
            if field is None:
                field = getattr(meta, 'columns', {}).get(name)
        if field is None:
            raise ValueError("Cannot sort on field '{}'".format(name))
        return field
####################################################################
# Model List
# XXX: Restrict which fields can be filtered
# XXX: Add sort capabilities
# XXX: Do we want to add encryption support? (yes but it should be outside here)
####################################################################
class ModelCRUD:
    """
    List and retrieve model rows through a base query and a paginator

    Configure by setting `query` (a `peewee.Query`) and `paginator` (a
    `Pagination` instance) on a subclass, or by overriding `get_query()`
    and `get_paginator()`.

    TODO: validate sort parameters against `sort_fields` (an earlier draft
    generated a marshmallow schema accepting only 'asc'/'desc', lower-cased
    and stripped) and restrict filtering to `filter_fields`.
    """

    paginator = None
    query = None
    sort_fields = []
    filter_fields = []

    def get_query(self):
        """Return query for our model"""
        return self.query

    def get_paginator(self):
        """Return pagination for our model"""
        return self.paginator

    def apply_filters(self, query, filters):
        """
        Apply user specified filters to query

        Hook for subclasses; the default applies no filtering and returns
        `query` unchanged.
        """
        assert isinstance(query, peewee.Query)
        assert isinstance(filters, dict)
        return query

    def list(self, filters, cursor, count):
        """
        List up to `count` items starting at `cursor`

        :attr filters: dict passed to `apply_filters()`
        :attr cursor: dict from `Model.to_cursor_ref()`, or {} for the
            first page
        :attr count: max items to return
        :returns: (items, next_cursor); next_cursor is None when there are
            no further items
        :raises ValueError: if `cursor` references more than one field
        """
        assert isinstance(filters, dict), "expected filters type 'dict'"
        assert isinstance(cursor, dict), "expected cursor type 'dict'"

        # start with our base query
        query = self.get_query()
        assert isinstance(query, peewee.Query)
        filtered = self.apply_filters(query, filters)
        # tolerate overrides written against the old stub, which returned None
        if filtered is not None:
            query = filtered

        paginator = self.get_paginator()
        assert isinstance(paginator, Pagination)

        # Fetch one row more than requested: if it comes back, it is the
        # first row of the next page and becomes the next cursor.
        offset = self._cursor_to_offset(cursor)
        pquery = paginator.paginate_query(query, count + 1, offset=offset)
        items = list(pquery)  # builtin list: method names aren't in scope here

        next_cursor = None
        if len(items) > count:
            next_cursor = items.pop().to_cursor_ref()
        return items, next_cursor

    @staticmethod
    def _cursor_to_offset(cursor):
        """Map a single-field cursor dict to a pagination offset value."""
        if not cursor:
            return None
        if len(cursor) != 1:
            raise ValueError("Cursor must reference exactly one primary key field")
        (value,) = cursor.values()
        return value

    def retrieve(self, cursor):
        """
        Retrieve the single item matching `cursor`

        :attr cursor: dict of field name -> value, e.g. from
            `Model.to_cursor_ref()`
        :returns: model instance
        :raises DoesNotExist: (the model's) if no row matches
        """
        assert isinstance(cursor, dict), "expected cursor type 'dict'"
        # look for record in query
        query = self.get_query()
        assert isinstance(query, peewee.Query)
        # Select.get() accepts no filter arguments, so narrow the query with
        # filter() first and then fetch the single row.
        if cursor:
            query = query.filter(**cursor)
        return query.get()