Skip to content

Commit

Permalink
Merge pull request #45 from juliotrigo/support_hybrid_attributes
Browse files Browse the repository at this point in the history
Add support for Hybrid attributes
  • Loading branch information
juliotrigo authored May 7, 2020
2 parents 1d3c39b + c0726c8 commit dee4948
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 2 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ Release Notes
Here you can see the full list of changes between sqlalchemy-filters
versions, where semantic versioning is used: *major.minor.patch*.

Unreleased
----------

* Add support for hybrid attributes (properties and methods): filtering
and sorting (#45) as a continuation of the work started here (#32)
by @vkylamba
- Addresses (#22)

0.11.0
------

Expand Down
31 changes: 31 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ Assuming that we have a SQLAlchemy_ ``query`` object:
name = Column(String(50), nullable=False)
count = Column(Integer, nullable=True)
@hybrid_property
def count_square(self):
return self.count * self.count
@hybrid_method
def three_times_count(self):
return self.count * 3
Base = declarative_base(cls=Base)
Expand Down Expand Up @@ -137,6 +145,21 @@ It is also possible to apply filters to queries defined by fields, functions or
query_alt_2 = session.query(func.count(Foo.id))
query_alt_3 = session.query().select_from(Foo).add_column(Foo.id)
Hybrid attributes
^^^^^^^^^^^^^^^^^
You can filter by a `hybrid attribute`_: a `hybrid property`_ or a `hybrid method`_.
.. code-block:: python
query = session.query(Foo)
filter_spec = [{'field': 'count_square', 'op': '>=', 'value': 25}]
filter_spec = [{'field': 'three_times_count', 'op': '>=', 'value': 15}]
filtered_query = apply_filters(query, filter_spec)
result = filtered_query.all()
Restricted Loads
----------------
Expand Down Expand Up @@ -241,6 +264,11 @@ The behaviour is the same as in ``apply_filters``.
This allows flexibility for clients to sort by fields on related objects
without specifying all possible joins on the server beforehand.
Hybrid attributes
^^^^^^^^^^^^^^^^^
You can sort by a `hybrid attribute`_: a `hybrid property`_ or a `hybrid method`_.
Pagination
----------
Expand Down Expand Up @@ -489,3 +517,6 @@ for details.
.. _SQLAlchemy: https://www.sqlalchemy.org/
.. _hybrid attribute: https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html
.. _hybrid property: https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html#sqlalchemy.ext.hybrid.hybrid_property
.. _hybrid method: https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html#sqlalchemy.ext.hybrid.hybrid_method
34 changes: 32 additions & 2 deletions sqlalchemy_filters/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.inspection import inspect
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.util import symbol
import types

from .exceptions import BadQuery, FieldNotFound, BadSpec

Expand All @@ -12,13 +14,41 @@ def __init__(self, model, field_name):
self.field_name = field_name

def get_sqlalchemy_field(self):
if self.field_name not in inspect(self.model).columns.keys():
if self.field_name not in self._get_valid_field_names():
raise FieldNotFound(
'Model {} has no column `{}`.'.format(
self.model, self.field_name
)
)
return getattr(self.model, self.field_name)
sqlalchemy_field = getattr(self.model, self.field_name)

# If it's a hybrid method, then we call it so that we can work with
# the result of the execution and not with the method object itself
if isinstance(sqlalchemy_field, types.MethodType):
sqlalchemy_field = sqlalchemy_field()

return sqlalchemy_field

def _get_valid_field_names(self):
inspect_mapper = inspect(self.model)
columns = inspect_mapper.columns
orm_descriptors = inspect_mapper.all_orm_descriptors

column_names = columns.keys()
hybrid_names = [
key for key, item in orm_descriptors.items()
if _is_hybrid_property(item) or _is_hybrid_method(item)
]

return set(column_names) | set(hybrid_names)


def _is_hybrid_property(orm_descriptor):
return orm_descriptor.extension_type == symbol('HYBRID_PROPERTY')


def _is_hybrid_method(orm_descriptor):
return orm_descriptor.extension_type == symbol('HYBRID_METHOD')


def get_query_models(query):
Expand Down
101 changes: 101 additions & 0 deletions test/interface/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,3 +1215,104 @@ def test_not_any_values_in_array(self, session, is_postgresql):
assert len(result) == 2
assert result[0].id == 1
assert result[1].id == 4


class TestHybridAttributes:

@pytest.mark.usefixtures('multiple_bars_inserted')
@pytest.mark.parametrize(
('field, expected_error'),
[
('foos', "Model <class 'test.models.Bar'> has no column `foos`."),
(
'__mapper__',
"Model <class 'test.models.Bar'> has no column `__mapper__`.",
),
(
'not_valid',
"Model <class 'test.models.Bar'> has no column `not_valid`.",
),
]
)
def test_orm_descriptors_not_valid_hybrid_attributes(
self, session, field, expected_error
):
query = session.query(Bar)
filters = [
{
'model': 'Bar',
'field': field,
'op': '==',
'value': 100
}
]
with pytest.raises(FieldNotFound) as exc:
apply_filters(query, filters)

assert expected_error in str(exc)

@pytest.mark.usefixtures('multiple_bars_inserted')
@pytest.mark.usefixtures('multiple_quxs_inserted')
def test_filter_by_hybrid_properties(self, session):
query = session.query(Bar, Qux)
filters = [
{
'model': 'Bar',
'field': 'count_square',
'op': '==',
'value': 100
},
{
'model': 'Qux',
'field': 'count_square',
'op': '>=',
'value': 26
},
]

filtered_query = apply_filters(query, filters)
result = filtered_query.all()

assert len(result) == 2
bars, quxs = zip(*result)

assert set(map(type, bars)) == {Bar}
assert {bar.id for bar in bars} == {2}
assert {bar.count_square for bar in bars} == {100}

assert set(map(type, quxs)) == {Qux}
assert {qux.id for qux in quxs} == {2, 4}
assert {qux.count_square for qux in quxs} == {100, 225}

@pytest.mark.usefixtures('multiple_bars_inserted')
@pytest.mark.usefixtures('multiple_quxs_inserted')
def test_filter_by_hybrid_methods(self, session):
query = session.query(Bar, Qux)
filters = [
{
'model': 'Bar',
'field': 'three_times_count',
'op': '==',
'value': 30
},
{
'model': 'Qux',
'field': 'three_times_count',
'op': '>=',
'value': 31
},
]

filtered_query = apply_filters(query, filters)
result = filtered_query.all()

assert len(result) == 1
bars, quxs = zip(*result)

assert set(map(type, bars)) == {Bar}
assert {bar.id for bar in bars} == {2}
assert {bar.three_times_count() for bar in bars} == {30}

assert set(map(type, quxs)) == {Qux}
assert {qux.id for qux in quxs} == {4}
assert {qux.three_times_count() for qux in quxs} == {45}
65 changes: 65 additions & 0 deletions test/interface/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,68 @@ def test_multiple_sort_fields_desc_nulls_last(
('name_4', None),
('name_5', 50),
]


class TestSortHybridAttributes(object):

"""Tests that results are sorted only according to the provided
filters.
Does NOT test how rows with the same values are sorted since this is
not consistent across RDBMS.
Does NOT test whether `NULL` field values are placed first or last
when sorting since this may differ across RDBMSs.
SQL defines that `NULL` values should be placed together when
sorting, but it does not specify whether they should be placed first
or last.
"""

@pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted')
def test_single_sort_hybrid_property_asc(self, session):
query = session.query(Bar)
order_by = [{'field': 'count_square', 'direction': 'asc'}]

sorted_query = apply_sort(query, order_by)
results = sorted_query.all()

assert [result.count_square for result in results] == [
1, 4, 4, 9, 25, 100, 144, 225
]

@pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted')
def test_single_sort_hybrid_property_desc(self, session):
query = session.query(Bar)
order_by = [{'field': 'count_square', 'direction': 'desc'}]

sorted_query = apply_sort(query, order_by)
results = sorted_query.all()

assert [result.count_square for result in results] == [
225, 144, 100, 25, 9, 4, 4, 1
]

@pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted')
def test_single_sort_hybrid_method_asc(self, session):
query = session.query(Bar)
order_by = [{'field': 'three_times_count', 'direction': 'asc'}]

sorted_query = apply_sort(query, order_by)
results = sorted_query.all()

assert [result.three_times_count() for result in results] == [
3, 6, 6, 9, 15, 30, 36, 45
]

@pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted')
def test_single_sort_hybrid_method_desc(self, session):
query = session.query(Bar)
order_by = [{'field': 'three_times_count', 'direction': 'desc'}]

sorted_query = apply_sort(query, order_by)
results = sorted_query.all()

assert [result.three_times_count() for result in results] == [
45, 36, 30, 15, 9, 6, 6, 3
]
9 changes: 9 additions & 0 deletions test/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
)
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
from sqlalchemy.orm import relationship


Expand All @@ -13,6 +14,14 @@ class Base(object):
name = Column(String(50), nullable=False)
count = Column(Integer, nullable=True)

@hybrid_property
def count_square(self):
return self.count * self.count

@hybrid_method
def three_times_count(self):
return self.count * 3


Base = declarative_base(cls=Base)
BasePostgresqlSpecific = declarative_base(cls=Base)
Expand Down

0 comments on commit dee4948

Please sign in to comment.