Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix, load options and limits for many to many truncating results #1389

Merged
merged 4 commits into from
Jun 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 88 additions & 53 deletions flask_appbuilder/models/sqla/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import aliased, Load
from sqlalchemy.orm import aliased, contains_eager, Load, load_only
from sqlalchemy.orm.descriptor_props import SynonymProperty
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_utils.types.uuid import UUIDType

from . import filters, Model
from ..base import BaseInterface
from ..filters import Filters
from ..group import GroupByCol, GroupByDateMonth, GroupByDateYear
from ..mixins import FileColumn, ImageColumn
from ..._compat import as_unicode
Expand Down Expand Up @@ -82,11 +83,9 @@ def model_name(self):
def is_model_already_joined(query, model):
return model in [mapper.class_ for mapper in query._join_entities]

def _get_base_query(
self, query=None, filters=None, order_column="", order_direction=""
):
if filters:
query = filters.apply_all(query)
def _apply_query_order(
self, query, order_column: str, order_direction: str
) -> BaseQuery:
if order_column != "":
# if Model has custom decorator **renders('<COL_NAME>')**
# this decorator will add a property to the method named *_col_name*
Expand All @@ -99,6 +98,13 @@ def _get_base_query(
query = query.order_by(self._get_attr(order_column).desc())
return query

def _get_base_query(
self, query=None, filters=None, order_column="", order_direction=""
):
if filters:
query = filters.apply_all(query)
return self._apply_query_order(query, order_column, order_direction)

def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuery:
"""
Helper function that applies necessary joins for dotted columns on a
Expand Down Expand Up @@ -153,31 +159,41 @@ def _query_select_options(
if is_column_dotted(column):
root_relation = get_column_root_relation(column)
leaf_column = get_column_leaf(column)
if root_relation not in joined_models:
if self.is_relation_many_to_many(
root_relation
) or self.is_relation_one_to_many(root_relation):
load_options.append(
(
Load(self.obj)
.joinedload(root_relation)
.load_only(leaf_column)
)
)
continue
elif root_relation not in joined_models:
query = self._query_join_relation(query, root_relation)
joined_models.append(root_relation)
load_options.append(
(
Load(self.obj)
.joinedload(root_relation)
.load_only(leaf_column)
)
(contains_eager(root_relation).load_only(leaf_column))
)
else:
# is a custom property method field?
if hasattr(getattr(self.obj, column), "fget"):
pass
# is not a relation and not a function?
elif not self.is_relation(column) and not hasattr(
getattr(self.obj, column), "__call__"
):
load_options.append(Load(self.obj).load_only(column))
# it's a normal column
else:
load_options.append(Load(self.obj))
if not self.is_relation(
column
) and not self.is_property_or_function(column):
load_options.append(load_only(column))
query = query.options(*tuple(load_options))
return query

def _get_non_dotted_filters(self, filters):
dotted_filters = Filters(self.filter_converter_class, self, [], [])
_filters = []
if filters:
for flt, value in zip(filters.filters, filters.values):
if not is_column_dotted(flt.column_name):
_filters.append((flt.column_name, flt.__class__, value))
dotted_filters.add_filter_list(_filters)
return dotted_filters

def query(
self,
filters=None,
Expand All @@ -202,8 +218,6 @@ def query(
the current page size
"""
query = self.session.query(self.obj)
query = self._query_join_dotted_column(query, order_column)
query = self._query_select_options(query, select_columns)
query_count = self.session.query(func.count("*")).select_from(self.obj)

query_count = self._get_base_query(query=query_count, filters=filters)
Expand All @@ -218,6 +232,23 @@ def query(
pk_name = self.get_pk_name()
query = query.order_by(pk_name)

# If order by is not dotted (related) we need to apply it first
if not is_column_dotted(order_column):
query = self._get_non_dotted_filters(filters).apply_all(query)
query = self._apply_query_order(query, order_column, order_direction)

# Pagination comes first
if page and page_size:
query = query.offset(page * page_size)
if page_size:
query = query.limit(page_size)

if select_columns and order_column:
# Use from self strategy
select_columns = select_columns + [order_column]
# Everything uses an inner query because of joins to m/m m/1
query = self._query_select_options(query.from_self(), select_columns)

query = self._get_base_query(
query=query,
filters=filters,
Expand All @@ -226,11 +257,6 @@ def query(
)

count = query_count.scalar()

if page and page_size:
query = query.offset(page * page_size)
if page_size:
query = query.limit(page_size)
return count, query.all()

def query_simple_group(
Expand Down Expand Up @@ -262,19 +288,19 @@ def query_year_group(self, group_by="", filters=None):
-----------------------------------------
"""

def is_image(self, col_name):
def is_image(self, col_name: str) -> bool:
try:
return isinstance(self.list_columns[col_name].type, ImageColumn)
except Exception:
return False

def is_file(self, col_name):
def is_file(self, col_name: str) -> bool:
try:
return isinstance(self.list_columns[col_name].type, FileColumn)
except Exception:
return False

def is_string(self, col_name):
def is_string(self, col_name: str) -> bool:
try:
return (
_is_sqla_type(self.list_columns[col_name].type, sa.types.String)
Expand All @@ -283,97 +309,97 @@ def is_string(self, col_name):
except Exception:
return False

def is_text(self, col_name):
def is_text(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Text)
except Exception:
return False

def is_binary(self, col_name):
def is_binary(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.LargeBinary)
except Exception:
return False

def is_integer(self, col_name):
def is_integer(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Integer)
except Exception:
return False

def is_numeric(self, col_name):
def is_numeric(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Numeric)
except Exception:
return False

def is_float(self, col_name):
def is_float(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Float)
except Exception:
return False

def is_boolean(self, col_name):
def is_boolean(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Boolean)
except Exception:
return False

def is_date(self, col_name):
def is_date(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Date)
except Exception:
return False

def is_datetime(self, col_name):
def is_datetime(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.DateTime)
except Exception:
return False

def is_enum(self, col_name):
def is_enum(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Enum)
except Exception:
return False

def is_relation(self, col_name):
def is_relation(self, col_name: str) -> bool:
try:
return isinstance(
self.list_properties[col_name], sa.orm.properties.RelationshipProperty
)
except Exception:
return False

def is_relation_many_to_one(self, col_name):
def is_relation_many_to_one(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "MANYTOONE"
except Exception:
return False

def is_relation_many_to_many(self, col_name):
def is_relation_many_to_many(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "MANYTOMANY"
except Exception:
return False

def is_relation_one_to_one(self, col_name):
def is_relation_one_to_one(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "ONETOONE"
except Exception:
return False

def is_relation_one_to_many(self, col_name):
def is_relation_one_to_many(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "ONETOMANY"
except Exception:
return False

def is_nullable(self, col_name):
def is_nullable(self, col_name: str) -> bool:
if self.is_relation_many_to_one(col_name):
col = self.get_relation_fk(col_name)
return col.nullable
Expand All @@ -382,28 +408,37 @@ def is_nullable(self, col_name):
except Exception:
return False

def is_unique(self, col_name):
def is_unique(self, col_name: str) -> bool:
try:
return self.list_columns[col_name].unique is True
except Exception:
return False

def is_pk(self, col_name):
def is_pk(self, col_name: str) -> bool:
try:
return self.list_columns[col_name].primary_key
except Exception:
return False

def is_pk_composite(self):
def is_pk_composite(self) -> bool:
return len(self.obj.__mapper__.primary_key) > 1

def is_fk(self, col_name):
def is_fk(self, col_name: str) -> bool:
try:
return self.list_columns[col_name].foreign_keys
except Exception:
return False

def get_max_length(self, col_name):
def is_property(self, col_name: str) -> bool:
return hasattr(getattr(self.obj, col_name), "fget")

def is_function(self, col_name: str) -> bool:
return hasattr(getattr(self.obj, col_name), "__call__")

def is_property_or_function(self, col_name: str) -> bool:
return self.is_property(col_name) or self.is_function(col_name)

def get_max_length(self, col_name: str) -> int:
try:
if self.is_enum(col_name):
return -1
Expand Down
4 changes: 4 additions & 0 deletions flask_appbuilder/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class ModelDottedMMApi(ModelRestApi):
list_columns = ["field_string", "children.field_integer"]
show_columns = ["field_string", "children.field_integer"]

self.modeldottedmmapi = ModelDottedMMApi
self.appbuilder.add_api(ModelDottedMMApi)

class ModelOMParentApi(ModelRestApi):
Expand Down Expand Up @@ -910,8 +911,11 @@ def test_get_list_dotted_mm_field(self):
rv = self.auth_client_get(client, token, uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["count"], MODEL2_DATA_SIZE)
self.assertEqual(len(data[API_RESULT_RES_KEY]), self.modeldottedmmapi.page_size)
i = 0
self.assertEqual(data[API_RESULT_RES_KEY][i]["field_string"], "0")
self.assertEqual(len(data[API_RESULT_RES_KEY][i]["children"]), 3)
self.assertIn({"field_integer": 1}, data[API_RESULT_RES_KEY][i]["children"])
self.assertIn({"field_integer": 2}, data[API_RESULT_RES_KEY][i]["children"])
self.assertIn({"field_integer": 3}, data[API_RESULT_RES_KEY][i]["children"])
Expand Down