Skip to content

Commit

Permalink
Refactor, upgrade get_nested_collection_attributes
Browse files Browse the repository at this point in the history
1. Upgrade Query to Select
2. Factor out query-building logic. The previous version returned tuples
   of items OR models (ORM objects), depending on the calling code
   (several similar data access methods were combined into this one
   generic method in PR #12056). The Query object would "magically"
   convert tuples of ORM objects to ORM objects. The new unified Select
   object does not do that. As a result, with Select, this method would
   return tuples of items or tuples of models (not models):

   result1 = session.execute(statement1)
   result1 == [("element_identifier_0", "element_identifier_1", "extension", "state"), ...]

   result2 = session.execute(statement2)
   result2 == [(dataset1,), (dataset2,) ...]

   Factoring out the query-building logic and having the caller execute
   it depending on the expected data structure solves this.
  • Loading branch information
jdavcs committed Nov 27, 2023
1 parent f5c93cd commit c0f6b1d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 32 deletions.
66 changes: 43 additions & 23 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6118,7 +6118,7 @@ def __init__(self, id=None, collection_type=None, populated=True, element_count=
self.populated_state = DatasetCollection.populated_states.NEW
self.element_count = element_count

def _get_nested_collection_attributes(
def _build_nested_collection_attributes_stmt(
self,
collection_attributes: Optional[Iterable[str]] = None,
element_attributes: Optional[Iterable[str]] = None,
Expand All @@ -6145,10 +6145,8 @@ def _get_nested_collection_attributes(
dataset_permission_attributes = dataset_permission_attributes or ()
return_entities = return_entities or ()
dataset_collection = self
db_session = object_session(self)
dc = alias(DatasetCollection)
dce = alias(DatasetCollectionElement)

depth_collection_type = dataset_collection.collection_type
order_by_columns = [dce.c.element_index]
nesting_level = 0
Expand All @@ -6158,14 +6156,15 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
return [getattr(column_collection, a).label(f"{a}{label_fragment}") for a in attributes]

q = (
db_session.query(
select(
*attribute_columns(dce.c, element_attributes, nesting_level),
*attribute_columns(dc.c, collection_attributes, nesting_level),
)
.select_from(dce, dc)
.join(dce, dce.c.dataset_collection_id == dc.c.id)
.filter(dc.c.id == dataset_collection.id)
)

while ":" in depth_collection_type:
nesting_level += 1
inner_dc = alias(DatasetCollection)
Expand Down Expand Up @@ -6203,15 +6202,21 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
q = q.add_columns(entity)
if entity == DatasetCollectionElement:
q = q.filter(entity.id == dce.c.id)
return q.distinct().order_by(*order_by_columns)

q = q.distinct().order_by(*order_by_columns)
return q

@property
def dataset_states_and_extensions_summary(self):
if not hasattr(self, "_dataset_states_and_extensions_summary"):
q = self._get_nested_collection_attributes(hda_attributes=("extension",), dataset_attributes=("state",))
stmt = self._build_nested_collection_attributes_stmt(
hda_attributes=("extension",), dataset_attributes=("state",)
)
col_attrs = object_session(self).execute(stmt)

extensions = set()
states = set()
for extension, state in q:
for extension, state in col_attrs:
states.add(state)
extensions.add(extension)

Expand All @@ -6225,8 +6230,10 @@ def has_deferred_data(self):
has_deferred_data = False
if object_session(self):
# TODO: Optimize by just querying without returning the states...
q = self._get_nested_collection_attributes(dataset_attributes=("state",))
for (state,) in q:
stmt = self._build_nested_collection_attributes_stmt(dataset_attributes=("state",))
col_attrs = object_session(self).execute(stmt)

for (state,) in col_attrs:
if state == Dataset.states.DEFERRED:
has_deferred_data = True
break
Expand All @@ -6247,13 +6254,16 @@ def populated_optimized(self):
if ":" not in self.collection_type:
_populated_optimized = self.populated_state == DatasetCollection.populated_states.OK
else:
q = self._get_nested_collection_attributes(
stmt = self._build_nested_collection_attributes_stmt(
collection_attributes=("populated_state",),
inner_filter=InnerCollectionFilter(
"populated_state", operator.__ne__, DatasetCollection.populated_states.OK
),
)
_populated_optimized = q.session.query(~exists(q.subquery())).scalar()
stmt = stmt.subquery()
stmt = select(~exists(stmt))
session = object_session(self)
_populated_optimized = session.scalar(stmt)

self._populated_optimized = _populated_optimized

Expand All @@ -6269,9 +6279,11 @@ def populated(self):
@property
def dataset_action_tuples(self):
if not hasattr(self, "_dataset_action_tuples"):
q = self._get_nested_collection_attributes(dataset_permission_attributes=("action", "role_id"))
stmt = self._build_nested_collection_attributes_stmt(dataset_permission_attributes=("action", "role_id"))
col_attrs = object_session(self).execute(stmt)

_dataset_action_tuples = []
for _dataset_action_tuple in q:
for _dataset_action_tuple in col_attrs:
if _dataset_action_tuple[0] is None:
continue
_dataset_action_tuples.append(_dataset_action_tuple)
Expand All @@ -6282,24 +6294,26 @@ def dataset_action_tuples(self):

@property
def element_identifiers_extensions_and_paths(self):
q = self._get_nested_collection_attributes(
stmt = self._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",), hda_attributes=("extension",), return_entities=(Dataset,)
)
return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in q]
col_attrs = object_session(self).execute(stmt)
return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in col_attrs]

@property
def element_identifiers_extensions_paths_and_metadata_files(
self,
) -> List[List[Any]]:
results = []
if object_session(self):
q = self._get_nested_collection_attributes(
stmt = self._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",),
hda_attributes=("extension",),
return_entities=(HistoryDatasetAssociation, Dataset),
)
col_attrs = object_session(self).execute(stmt)
# element_identifiers, extension, path
for row in q:
for row in col_attrs:
result = [row[:-3], row.extension, row.Dataset.get_file_name()]
hda = row.HistoryDatasetAssociation
result.append(hda.get_metadata_file_paths_and_extensions())
Expand Down Expand Up @@ -6344,7 +6358,8 @@ def finalize(self, collection_type_description):
def dataset_instances(self):
db_session = object_session(self)
if db_session and self.id:
return self._get_nested_collection_attributes(return_entities=(HistoryDatasetAssociation,)).all()
stmt = self._build_nested_collection_attributes_stmt(return_entities=(HistoryDatasetAssociation,))
return db_session.scalars(stmt).all()
else:
# Sessionless context
instances = []
Expand All @@ -6360,7 +6375,8 @@ def dataset_instances(self):
def dataset_elements(self):
db_session = object_session(self)
if db_session and self.id:
return self._get_nested_collection_attributes(return_entities=(DatasetCollectionElement,)).all()
stmt = self._build_nested_collection_attributes_stmt(return_entities=(DatasetCollectionElement,))
return db_session.scalars(stmt).all()
elements = []
for element in self.elements:
if element.is_collection:
Expand Down Expand Up @@ -6445,9 +6461,11 @@ def copy(
return new_collection

def replace_failed_elements(self, replacements):
hda_id_to_element = dict(
self._get_nested_collection_attributes(return_entities=[DatasetCollectionElement], hda_attributes=["id"])
stmt = self._build_nested_collection_attributes_stmt(
return_entities=[DatasetCollectionElement], hda_attributes=["id"]
)
col_attrs = object_session(self).execute(stmt).all()
hda_id_to_element = dict(col_attrs)
for failed, replacement in replacements.items():
element = hda_id_to_element.get(failed.id)
if element:
Expand Down Expand Up @@ -6712,10 +6730,12 @@ def job_state_summary_dict(self):
@property
def dataset_dbkeys_and_extensions_summary(self):
if not hasattr(self, "_dataset_dbkeys_and_extensions_summary"):
rows = self.collection._get_nested_collection_attributes(hda_attributes=("_metadata", "extension"))
stmt = self.collection._build_nested_collection_attributes_stmt(hda_attributes=("_metadata", "extension"))
col_attrs = object_session(self).execute(stmt)

extensions = set()
dbkeys = set()
for row in rows:
for row in col_attrs:
if row is not None:
dbkey_field = row._metadata.get("dbkey")
if isinstance(dbkey_field, list):
Expand Down
34 changes: 25 additions & 9 deletions test/unit/data/test_galaxy_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def test_collections_in_library_folders(self):
# assert len(loaded_dataset_collection.datasets) == 2
# assert loaded_dataset_collection.collection_type == "pair"

# TODO breakup this test into separate tests that test the model's public attributes, not the internal query-building logic
def test_nested_collection_attributes(self):
u = model.User(email="mary2@example.com", password="password")
h1 = model.History(name="History 1", user=u)
Expand Down Expand Up @@ -392,18 +393,31 @@ def test_nested_collection_attributes(self):
)
self.model.session.add_all([d1, d2, c1, dce1, dce2, c2, dce3, c3, c4, dce4])
self.model.session.flush()
q = c2._get_nested_collection_attributes(

stmt = c2._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",), hda_attributes=("extension",), dataset_attributes=("state",)
)
assert [(r._fields) for r in q] == [
result = self.model.session.execute(stmt).all()
assert [(r._fields) for r in result] == [
("element_identifier_0", "element_identifier_1", "extension", "state"),
("element_identifier_0", "element_identifier_1", "extension", "state"),
]
assert q.all() == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]
q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation,))
assert q.all() == [d1, d2]
q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation, model.Dataset))
assert q.all() == [(d1, d1.dataset), (d2, d2.dataset)]

stmt = c2._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",), hda_attributes=("extension",), dataset_attributes=("state",)
)
result = self.model.session.execute(stmt).all()
assert result == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]

stmt = c2._build_nested_collection_attributes_stmt(return_entities=(model.HistoryDatasetAssociation,))
result = self.model.session.scalars(stmt).all()
assert result == [d1, d2]

stmt = c2._build_nested_collection_attributes_stmt(
return_entities=(model.HistoryDatasetAssociation, model.Dataset)
)
result = self.model.session.execute(stmt).all()
assert result == [(d1, d1.dataset), (d2, d2.dataset)]
# Assert properties that use _get_nested_collection_attributes return correct content
assert c2.dataset_instances == [d1, d2]
assert c2.dataset_elements == [dce1, dce2]
Expand All @@ -422,8 +436,10 @@ def test_nested_collection_attributes(self):
assert c3.dataset_instances == []
assert c3.dataset_elements == []
assert c3.dataset_states_and_extensions_summary == (set(), set())
q = c4._get_nested_collection_attributes(element_attributes=("element_identifier",))
assert q.all() == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]

stmt = c4._build_nested_collection_attributes_stmt(element_attributes=("element_identifier",))
result = self.model.session.execute(stmt).all()
assert result == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]
assert c4.dataset_elements == [dce1, dce2]
assert c4.element_identifiers_extensions_and_paths == [
(("outer_list", "inner_list", "forward"), "bam", "mock_dataset_14.dat"),
Expand Down

0 comments on commit c0f6b1d

Please sign in to comment.