Skip to content

Commit

Permalink
general pep8 and more clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
MRigal committed Apr 29, 2015
1 parent 42ba4a5 commit 042cc66
Show file tree
Hide file tree
Showing 16 changed files with 92 additions and 131 deletions.
31 changes: 24 additions & 7 deletions mongoengine/base/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import weakref
import functools
import itertools

from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

Expand All @@ -21,7 +21,7 @@ def __init__(self, dict_items, instance, name):
if isinstance(instance, (Document, EmbeddedDocument)):
self._instance = weakref.proxy(instance)
self._name = name
return super(BaseDict, self).__init__(dict_items)
super(BaseDict, self).__init__(dict_items)

def __getitem__(self, key, *args, **kwargs):
value = super(BaseDict, self).__getitem__(key)
Expand Down Expand Up @@ -66,15 +66,15 @@ def __setstate__(self, state):

def clear(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseDict, self).clear(*args, **kwargs)
return super(BaseDict, self).clear()

def pop(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseDict, self).pop(*args, **kwargs)

def popitem(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseDict, self).popitem(*args, **kwargs)
return super(BaseDict, self).popitem()

def setdefault(self, *args, **kwargs):
self._mark_as_changed()
Expand Down Expand Up @@ -190,7 +190,7 @@ def remove(self, *args, **kwargs):

def reverse(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).reverse(*args, **kwargs)
return super(BaseList, self).reverse()

def sort(self, *args, **kwargs):
self._mark_as_changed()
Expand Down Expand Up @@ -369,45 +369,58 @@ class StrictDict(object):
__slots__ = ()
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
_classes = {}

def __init__(self, **kwargs):
for k,v in kwargs.iteritems():
for k, v in kwargs.iteritems():
setattr(self, k, v)

def __getitem__(self, key):
key = '_reserved_' + key if key in self._special_fields else key
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key)

def __setitem__(self, key, value):
key = '_reserved_' + key if key in self._special_fields else key
return setattr(self, key, value)

def __contains__(self, key):
return hasattr(self, key)

def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default

def pop(self, key, default=None):
v = self.get(key, default)
try:
delattr(self, key)
except AttributeError:
pass
return v

def iteritems(self):
for key in self:
yield key, self[key]

def items(self):
return [(k, self[k]) for k in iter(self)]

def keys(self):
return list(iter(self))

def __iter__(self):
return (key for key in self.__slots__ if hasattr(self, key))

def __len__(self):
return len(list(self.iteritems()))

def __eq__(self, other):
return self.items() == other.items()

def __neq__(self, other):
return self.items() != other.items()

Expand All @@ -418,15 +431,18 @@ def create(cls, allowed_keys):
if allowed_keys not in cls._classes:
class SpecificStrictDict(cls):
__slots__ = allowed_keys_tuple

def __repr__(self):
return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k,v) for (k,v) in self.iteritems())
return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k, v) for (k, v) in self.iteritems())

cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]


class SemiStrictDict(StrictDict):
__slots__ = ('_extras')
_classes = {}

def __getattr__(self, attr):
try:
super(SemiStrictDict, self).__getattr__(attr)
Expand All @@ -435,6 +451,7 @@ def __getattr__(self, attr):
return self.__getattribute__('_extras')[attr]
except KeyError as e:
raise AttributeError(e)

def __setattr__(self, attr, value):
try:
super(SemiStrictDict, self).__setattr__(attr, value)
Expand Down
16 changes: 7 additions & 9 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from mongoengine.errors import (ValidationError, InvalidDocumentError,
LookUpError, FieldDoesNotExist)
from mongoengine.python_support import PY3, txt_type

from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import (
BaseDict,
Expand Down Expand Up @@ -420,7 +419,7 @@ def to_json(self, *args, **kwargs):
:param use_db_field: Set to True by default but enables the output of the json structure with the field names and not the mongodb store db_names in case of set to False
"""
use_db_field = kwargs.pop('use_db_field', True)
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)

@classmethod
def from_json(cls, json_data, created=False):
Expand Down Expand Up @@ -569,7 +568,7 @@ def _get_changed_fields(self, inspected=None):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in changed_fields):
# Find all embedded fields that have been changed
# Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected)
changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
Expand Down Expand Up @@ -614,7 +613,7 @@ def _delta(self):
else:
set_data = doc
if '_id' in set_data:
del(set_data['_id'])
del set_data['_id']

# Determine if any changed items were actually unset.
for path, value in set_data.items():
Expand All @@ -625,7 +624,7 @@ def _delta(self):
default = None
if (self._dynamic and len(parts) and parts[0] in
self._dynamic_fields):
del(set_data[path])
del set_data[path]
unset_data[path] = 1
continue
elif path in self._fields:
Expand Down Expand Up @@ -659,7 +658,7 @@ def _delta(self):
if default != value:
continue

del(set_data[path])
del set_data[path]
unset_data[path] = 1
return set_data, unset_data

Expand Down Expand Up @@ -775,7 +774,7 @@ def _build_index_spec(cls, spec):
allow_inheritance = cls._meta.get('allow_inheritance',
ALLOW_INHERITANCE)
include_cls = (allow_inheritance and not spec.get('sparse', False) and
spec.get('cls', True))
spec.get('cls', True))

# 733: don't include cls if index_cls is False unless there is an explicit cls with the index
include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
Expand Down Expand Up @@ -805,7 +804,6 @@ def _build_index_spec(cls, spec):
parts = key.split('.')
if parts in (['pk'], ['id'], ['_id']):
key = '_id'
fields = []
else:
fields = cls._lookup_field(parts)
parts = []
Expand Down Expand Up @@ -966,7 +964,7 @@ def _lookup_field(cls, parts):
if hasattr(getattr(field, 'field', None), 'lookup_member'):
new_field = field.field.lookup_member(field_name)
else:
# Look up subfield on the previous field
# Look up subfield on the previous field
new_field = field.lookup_member(field_name)
if not new_field and isinstance(field, ComplexBaseField):
if hasattr(field.field, 'document_type') and cls._dynamic \
Expand Down
13 changes: 4 additions & 9 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from mongoengine.common import _import_class
from mongoengine.errors import ValidationError

from mongoengine.base.common import ALLOW_INHERITANCE
from mongoengine.base.datastructures import (
BaseDict, BaseList, EmbeddedDocumentList
Expand All @@ -18,7 +17,6 @@


class BaseField(object):

"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
Expand Down Expand Up @@ -191,7 +189,6 @@ def _validate(self, value, **kwargs):


class ComplexBaseField(BaseField):

"""Handles complex fields, such as lists / dictionaries.
Allows for nesting of embedded documents inside complex types.
Expand Down Expand Up @@ -308,8 +305,8 @@ def to_mongo(self, value):
return GenericReferenceField().to_mongo(value)
cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, EmbeddedDocument)):
# If it's a document that is not inherited add _cls
if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__
return val

Expand Down Expand Up @@ -348,8 +345,8 @@ def to_mongo(self, value):
elif hasattr(v, 'to_mongo'):
cls = v.__class__
val = v.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(v, (Document, EmbeddedDocument))):
# If it's a document that is not inherited add _cls
if isinstance(v, (Document, EmbeddedDocument)):
val['_cls'] = cls.__name__
value_dict[k] = val
else:
Expand Down Expand Up @@ -405,7 +402,6 @@ def _get_owner_document(self, owner_document):


class ObjectIdField(BaseField):

"""A field wrapper around MongoDB's ObjectIds.
"""

Expand Down Expand Up @@ -434,7 +430,6 @@ def validate(self, value):


class GeoJsonBaseField(BaseField):

"""A geo json field storing a geojson style object.
.. versionadded:: 0.8
Expand Down
19 changes: 6 additions & 13 deletions mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import warnings

import pymongo

from mongoengine.common import _import_class
from mongoengine.errors import InvalidDocumentError
from mongoengine.python_support import PY3
from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
MultipleObjectsReturned,
QuerySet, QuerySetManager)

QuerySetManager)
from mongoengine.base.common import _document_registry, ALLOW_INHERITANCE
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField

__all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass')


class DocumentMetaclass(type):

"""Metaclass for all documents.
"""

Expand Down Expand Up @@ -146,7 +142,7 @@ def __new__(cls, name, bases, attrs):
for base in document_bases:
if _cls not in base._subclasses:
base._subclasses += (_cls,)
base._types = base._subclasses # TODO depreciate _types
base._types = base._subclasses # TODO depreciate _types

(Document, EmbeddedDocument, DictField,
CachedReferenceField) = cls._import_classes()
Expand Down Expand Up @@ -251,7 +247,6 @@ def _import_classes(cls):


class TopLevelDocumentMetaclass(DocumentMetaclass):

"""Metaclass for top-level documents (i.e. documents that have their own
collection in the database.
"""
Expand All @@ -261,7 +256,7 @@ def __new__(cls, name, bases, attrs):
super_new = super(TopLevelDocumentMetaclass, cls).__new__

# Set default _meta data if base class, otherwise get user defined meta
if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass):
if attrs.get('my_metaclass') == TopLevelDocumentMetaclass:
# defaults
attrs['_meta'] = {
'abstract': True,
Expand All @@ -280,7 +275,7 @@ def __new__(cls, name, bases, attrs):
attrs['_meta'].update(attrs.get('meta', {}))
else:
attrs['_meta'] = attrs.get('meta', {})
# Explictly set abstract to false unless set
# Explicitly set abstract to false unless set
attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False)
attrs['_is_base_cls'] = False

Expand All @@ -295,7 +290,7 @@ def __new__(cls, name, bases, attrs):

# Clean up top level meta
if 'meta' in attrs:
del(attrs['meta'])
del attrs['meta']

# Find the parent document class
parent_doc_cls = [b for b in flattened_bases
Expand All @@ -308,7 +303,7 @@ def __new__(cls, name, bases, attrs):
and not parent_doc_cls._meta.get('abstract', True)):
msg = "Trying to set a collection on a subclass (%s)" % name
warnings.warn(msg, SyntaxWarning)
del(attrs['_meta']['collection'])
del attrs['_meta']['collection']

# Ensure abstract documents have abstract bases
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
Expand Down Expand Up @@ -411,7 +406,6 @@ def __new__(cls, name, bases, attrs):


class MetaDict(dict):

"""Custom dictionary for meta classes.
Handles the merging of set indexes
"""
Expand All @@ -426,6 +420,5 @@ def merge(self, new_options):


class BasesTuple(tuple):

"""Special class to handle introspection of bases tuple in __new__"""
pass
3 changes: 2 additions & 1 deletion mongoengine/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
try:
connection = None
# check for shared connections
connection_settings_iterator = ((db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
connection_settings_iterator = (
(db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
for db_alias, connection_settings in connection_settings_iterator:
connection_settings.pop('name', None)
connection_settings.pop('username', None)
Expand Down
Loading

0 comments on commit 042cc66

Please sign in to comment.