diff --git a/aiida/backends/sqlalchemy/migrations/env.py b/aiida/backends/sqlalchemy/migrations/env.py index d148bd54d2..9fa134f0f0 100644 --- a/aiida/backends/sqlalchemy/migrations/env.py +++ b/aiida/backends/sqlalchemy/migrations/env.py @@ -30,20 +30,19 @@ def run_migrations_online(): from aiida.backends.sqlalchemy.models.base import Base config = context.config # pylint: disable=no-member - connectable = config.attributes.get('connection', None) + connection = config.attributes.get('connection', None) - if connectable is None: + if connection is None: from aiida.common.exceptions import ConfigurationError raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.') - with connectable.connect() as connection: - context.configure( # pylint: disable=no-member - connection=connection, - target_metadata=Base.metadata, - transaction_per_migration=True, - ) + context.configure( # pylint: disable=no-member + connection=connection, + target_metadata=Base.metadata, + transaction_per_migration=True, + ) - context.run_migrations() # pylint: disable=no-member + context.run_migrations() # pylint: disable=no-member try: diff --git a/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py b/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py index 952bed3cac..a05796e0d5 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py +++ b/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module +# pylint: disable=invalid-name,no-member,import-error,no-name-in-module,protected-access """This migration cleans the log records from non-Node entity records. It removes from the DbLog table the legacy workflow records and records that correspond to an unknown entity and places them to corresponding files. 
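Note on the env.py hunk above: the contract is inverted so that Alembic receives an already-initialized Connection through config.attributes, rather than an Engine-like "connectable" that env.py opens itself; hence the dropped connect() context manager. A minimal sketch of the invoking side, following the Alembic connection-sharing recipe, with illustrative names (the actual AiiDA call site is not part of this diff):

    from alembic import command
    from alembic.config import Config

    def run_online_migrations(engine, alembic_ini_path):
        # Open the connection ourselves and hand it to Alembic; env.py
        # retrieves it via config.attributes.get('connection').
        config = Config(alembic_ini_path)
        with engine.connect() as connection:
            config.attributes['connection'] = connection
            command.upgrade(config, 'head')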
@@ -95,7 +95,7 @@ def get_serialized_legacy_workflow_logs(connection): ) res = list() for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) @@ -114,7 +114,7 @@ def get_serialized_unknown_entity_logs(connection): ) res = list() for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) @@ -133,7 +133,7 @@ def get_serialized_logs_with_no_nodes(connection): ) res = list() for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) diff --git a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py index d53ec44ce3..d42f7d0813 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py +++ b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py @@ -46,7 +46,7 @@ def upgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, uuid in nodes: @@ -64,7 +64,7 @@ def downgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, _ in nodes: diff --git a/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py b/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py index 2b0eed82a1..cabb6b487a 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py @@ -58,9 +58,9 @@ def export_workflow_data(connection): DbWorkflowData = table('db_dbworkflowdata') DbWorkflowStep = table('db_dbworkflowstep') - count_workflow = connection.execute(select([func.count()]).select_from(DbWorkflow)).scalar() - count_workflow_data = connection.execute(select([func.count()]).select_from(DbWorkflowData)).scalar() - count_workflow_step = connection.execute(select([func.count()]).select_from(DbWorkflowStep)).scalar() + count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() + count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() + count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() # Nothing to do if all tables are empty if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0: @@ -78,9 +78,9 @@ def export_workflow_data(connection): delete_on_close = configuration.PROFILE.is_test_profile data = { - 'workflow': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflow))], - 'workflow_data': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowData))], - 'workflow_step': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowStep))], + 'workflow': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflow))], + 'workflow_data': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowData))], + 'workflow_step': 
[dict(row) for row in connection.execute(select('*').select_from(DbWorkflowStep))], } with NamedTemporaryFile( diff --git a/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py b/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py index 24a74a6c19..304f7077de 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py @@ -45,7 +45,7 @@ def upgrade(): ) profile = get_profile() - node_count = connection.execute(select([func.count()]).select_from(DbNode)).scalar() + node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() missing_repo_folder = [] shard_count = 256 diff --git a/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py b/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py index a0ff49e325..33f3edfaef 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py +++ b/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py @@ -41,7 +41,7 @@ def migrate_infer_calculation_entry_point(connection): column('process_type', String) ) - query_set = connection.execute(select([DbNode.c.type]).where(DbNode.c.type.like('calculation.%'))).fetchall() + query_set = connection.execute(select(DbNode.c.type).where(DbNode.c.type.like('calculation.%'))).fetchall() type_strings = set(entry[0] for entry in query_set) mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings) @@ -54,7 +54,7 @@ def migrate_infer_calculation_entry_point(connection): # All affected entries should be logged to file that the user can consult. 
if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: query_set = connection.execute( - select([DbNode.c.uuid]).where(DbNode.c.type == op.inline_literal(type_string)) + select(DbNode.c.uuid).where(DbNode.c.type == op.inline_literal(type_string)) ).fetchall() uuids = [str(entry.uuid) for entry in query_set] diff --git a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py index bd0ad4409f..91ba715abd 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py +++ b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py @@ -31,8 +31,8 @@ def upgrade(): """Migrations for the upgrade.""" op.drop_table('db_dbpath') conn = op.get_bind() - conn.execute('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink') - conn.execute('DROP FUNCTION IF EXISTS update_tc()') + conn.execute(sa.text('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink')) + conn.execute(sa.text('DROP FUNCTION IF EXISTS update_tc()')) def downgrade(): diff --git a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py index 765a4eaa6a..1c36359b36 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py @@ -43,7 +43,7 @@ def upgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, uuid in nodes: @@ -61,11 +61,11 @@ def downgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, uuid in nodes: - attributes = connection.execute(select([DbNode.c.attributes]).where(DbNode.c.id == pk)).fetchone() + attributes = connection.execute(select(DbNode.c.attributes).where(DbNode.c.id == pk)).fetchone() symbols = numpy.array(attributes['symbols']) utils.store_numpy_array_in_repository(uuid, 'symbols', symbols) key = op.inline_literal('{"array|symbols"}') diff --git a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py index 6060e03ef7..b7e4a80fa6 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py @@ -33,7 +33,7 @@ def set_new_uuid(connection): from aiida.common.utils import get_new_uuid # Exit if there are no rows - e.g. 
initial setup - id_query = connection.execute('SELECT db_dblog.id FROM db_dblog') + id_query = connection.execute(sa.text('SELECT db_dblog.id FROM db_dblog')) if id_query.rowcount == 0: return @@ -52,7 +52,7 @@ def set_new_uuid(connection): UPDATE db_dblog as t SET uuid = uuid(c.uuid) from (values {key_values}) as c(id, uuid) where c.id = t.id""" - connection.execute(update_stm) + connection.execute(sa.text(update_stm)) def upgrade(): diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py index 73a7cba6cf..dd7f6ab9ad 100644 --- a/aiida/backends/sqlalchemy/models/base.py +++ b/aiida/backends/sqlalchemy/models/base.py @@ -11,7 +11,7 @@ """Base SQLAlchemy models.""" from sqlalchemy import orm -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base from sqlalchemy.orm.exc import UnmappedClassError import aiida.backends.sqlalchemy diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/backends/sqlalchemy/utils.py index edb7369ff3..e7f690aa17 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/backends/sqlalchemy/utils.py @@ -60,6 +60,8 @@ def install_tc(session): """ Install the transitive closure table with SqlAlchemy. """ + from sqlalchemy import text + links_table_name = 'db_dblink' links_table_input_field = 'input_id' links_table_output_field = 'output_id' @@ -68,9 +70,11 @@ def install_tc(session): closure_table_child_field = 'child_id' session.execute( - get_pg_tc( - links_table_name, links_table_input_field, links_table_output_field, closure_table_name, - closure_table_parent_field, closure_table_child_field + text( + get_pg_tc( + links_table_name, links_table_input_field, links_table_output_field, closure_table_name, + closure_table_parent_field, closure_table_child_field + ) ) ) diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index 0b42aa378d..30ab18ae01 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -38,14 +38,14 @@ def create_sqlalchemy_engine(profile, **kwargs): name=profile.database_name ) return create_engine( - engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=False, encoding='utf-8', **kwargs + engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=True, encoding='utf-8', **kwargs ) def create_scoped_session_factory(engine, **kwargs): """Create scoped SQLAlchemy session factory""" from sqlalchemy.orm import scoped_session, sessionmaker - return scoped_session(sessionmaker(bind=engine, **kwargs)) + return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) def delete_nodes_and_connections(pks): diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 3661ee44a7..89f9efffa4 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -83,18 +83,13 @@ def transaction(self): entering. Transactions can be nested. 
""" session = self.get_session() - nested = session.transaction.nested - try: - session.begin_nested() - yield session - session.commit() - except Exception: - session.rollback() - raise - finally: - if not nested: - # Make sure to commit the outermost session - session.commit() + if session.in_transaction(): + with session.begin_nested(): + yield session + else: + with session.begin(): + with session.begin_nested(): + yield session @staticmethod def get_session(): @@ -131,10 +126,11 @@ def execute_raw(self, query): :param query: a string containing a raw SQL statement :return: the result of the query """ + from sqlalchemy import text from sqlalchemy.exc import ResourceClosedError # pylint: disable=import-error,no-name-in-module with self.transaction() as session: - queryset = session.execute(query) + queryset = session.execute(text(query)) try: results = queryset.fetchall() diff --git a/aiida/orm/implementation/sqlalchemy/computers.py b/aiida/orm/implementation/sqlalchemy/computers.py index 30eb2339c0..14e500f05b 100644 --- a/aiida/orm/implementation/sqlalchemy/computers.py +++ b/aiida/orm/implementation/sqlalchemy/computers.py @@ -137,7 +137,7 @@ def list_names(): def delete(self, pk): try: session = get_scoped_session() - session.query(DbComputer).get(pk).delete() + session.get(DbComputer, pk).delete() session.commit() except SQLAlchemyError as exc: raise exceptions.InvalidOperation( diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 482f264e95..d6c34e5a9f 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -367,5 +367,5 @@ def query( def delete(self, id): # pylint: disable=redefined-builtin session = sa.get_scoped_session() - session.query(DbGroup).get(id).delete() + session.get(DbGroup, id).delete() session.commit() diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py index e32e784901..2a1f996f31 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py @@ -79,7 +79,8 @@ class SqlaJoiner: """A class containing the logic for SQLAlchemy entities joining entities.""" def __init__( - self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], BooleanClauseList] + self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], + Optional[BooleanClauseList]] ): """Initialise the class""" self._entities = entity_mapper @@ -185,7 +186,13 @@ def _join_descendants_recursive( link1 = aliased(self._entities.Link) link2 = aliased(self._entities.Link) node1 = aliased(self._entities.Node) + + link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links in_recursive_filters = self._build_filters(node1, filter_dict) + if in_recursive_filters is None: + filters = link_filters + else: + filters = and_(in_recursive_filters, link_filters) selection_walk_list = [ link1.input_id.label('ancestor_id'), @@ -195,12 +202,8 @@ def _join_descendants_recursive( if expand_path: selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) - walk = select(selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where( - and_( - in_recursive_filters, # I apply filters for speed here - link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links - ) - 
).cte(recursive=True) + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id) + ).where(filters).cte(recursive=True) aliased_walk = aliased(walk) @@ -214,13 +217,12 @@ def _join_descendants_recursive( descendants_recursive = aliased( aliased_walk.union_all( - select(selection_union_list).select_from( - join( - aliased_walk, - link2, - link2.input_id == aliased_walk.c.descendant_id, - ) - ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + select(*selection_union_list + ).select_from(join( + aliased_walk, + link2, + link2.input_id == aliased_walk.c.descendant_id, + )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) ) ) # .alias() @@ -249,7 +251,13 @@ def _join_ancestors_recursive( link1 = aliased(self._entities.Link) link2 = aliased(self._entities.Link) node1 = aliased(self._entities.Node) + + link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links in_recursive_filters = self._build_filters(node1, filter_dict) + if in_recursive_filters is None: + filters = link_filters + else: + filters = and_(in_recursive_filters, link_filters) selection_walk_list = [ link1.input_id.label('ancestor_id'), @@ -259,9 +267,8 @@ def _join_ancestors_recursive( if expand_path: selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) - walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( - and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) - ).cte(recursive=True) + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id) + ).where(filters).cte(recursive=True) aliased_walk = aliased(walk) @@ -275,13 +282,12 @@ def _join_ancestors_recursive( ancestors_recursive = aliased( aliased_walk.union_all( - select(selection_union_list).select_from( - join( - aliased_walk, - link2, - link2.output_id == aliased_walk.c.ancestor_id, - ) - ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + select(*selection_union_list + ).select_from(join( + aliased_walk, + link2, + link2.output_id == aliased_walk.c.ancestor_id, + )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) # I can't follow RETURN or CALL links ) ) diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py index 02c43aa80c..1c6848156d 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py @@ -19,8 +19,7 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.exc import SAWarning from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm import aliased, loading -from sqlalchemy.orm.context import ORMCompileState, QueryContext +from sqlalchemy.orm import aliased from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.orm.util import AliasedClass @@ -72,35 +71,6 @@ def compile(element, compiler: TypeCompiler, **kwargs): # pylint: disable=funct return f'jsonb_typeof({compiler.process(element.clauses, **kwargs)})' -def _orm_setup_cursor_result( - session, - statement, - params, - execution_options, - bind_arguments, - result, -): - """Patched class method.""" - execution_context = result.context - compile_state = execution_context.compiled.compile_state - - # this is the patch required for turning off de-duplication of 
results - compile_state._has_mapper_entities = False # pylint: disable=protected-access - - load_options = execution_options.get('_sa_orm_load_options', QueryContext.default_load_options) - - querycontext = QueryContext( - compile_state, - statement, - params, - session, - load_options, - execution_options, - bind_arguments, - ) - return loading.instances(result, querycontext) - - class SqlaQueryBuilder(BackendQueryBuilder): """ QueryBuilder to use with SQLAlchemy-backend and @@ -229,7 +199,9 @@ def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Li """Return an iterator over all the results of a list of lists.""" with self.use_query(data) as query: - for resultrow in query.yield_per(batch_size): # type: ignore[arg-type] # pylint: disable=not-an-iterable + stmt = query.statement.execution_options(yield_per=batch_size) + + for resultrow in self.get_session().execute(stmt): # we discard the first item of the result row, # which is what the query was initialised with # and not one of the requested projection (see self._build) @@ -240,7 +212,9 @@ def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[D """Return an iterator over all the results of a list of dictionaries.""" with self.use_query(data) as query: - for row in query.yield_per(batch_size): # type: ignore[arg-type] # pylint: disable=not-an-iterable + stmt = query.statement.execution_options(yield_per=batch_size) + + for row in self.get_session().execute(stmt): # build the yield result yield_result: Dict[str, Dict[str, Any]] = {} for tag, projected_entities_dict in self._tag_to_projected_fields.items(): @@ -255,20 +229,12 @@ def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[D @contextmanager def use_query(self, data: QueryDictType) -> Iterator[Query]: """Yield the built query.""" - # Currently, a monkey-patch is required to turn off de-duplication of results, - # carried out in the `use_query` method - # see: https://github.com/sqlalchemy/sqlalchemy/issues/4395#issuecomment-907293360 - # THIS CAN BE REMOVED WHEN MOVING TO THE VERSION 2 API - existing_func = ORMCompileState.orm_setup_cursor_result - ORMCompileState.orm_setup_cursor_result = _orm_setup_cursor_result # type: ignore[assignment] query = self._update_query(data) try: yield query except Exception: self.get_session().close() raise - finally: - ORMCompileState.orm_setup_cursor_result = existing_func # type: ignore[assignment] def _update_query(self, data: QueryDictType) -> Query: """Return the sqlalchemy.orm.Query instance for the current query specification. 
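The querybuilder/main.py hunks above remove two 1.3-era workarounds at once. Passing a statement to Session.execute() (the 1.4 style) does not uniquify ORM rows, so the _orm_setup_cursor_result monkey-patch for turning off de-duplication becomes unnecessary; and Query.yield_per() is replaced by the yield_per execution option on the statement. A self-contained sketch of the resulting iteration pattern, using an illustrative Item model and in-memory SQLite rather than the real AiiDA models:

    from sqlalchemy import Column, Integer, String, create_engine, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Item(Base):
        __tablename__ = 'item'
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine('sqlite://', future=True)
    Base.metadata.create_all(engine)

    with Session(engine, future=True) as session:
        session.add_all(Item(name=f'item-{i}') for i in range(10))
        session.commit()

        # Attach yield_per to the statement and iterate the Result directly:
        # rows are fetched in batches instead of being fully buffered.
        stmt = select(Item.id, Item.name).execution_options(yield_per=4)
        for row in session.execute(stmt):
            print(row._asdict())  # Row._asdict() replaces dict(zip(row.keys(), row))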
@@ -366,7 +332,9 @@ def _build(self) -> Query: alias = self._get_tag_alias(tag) except KeyError: raise ValueError(f'Unknown tag {tag!r} in filters, known: {list(self._tag_to_alias)}') - self._query = self._query.filter(self.build_filters(alias, filter_specs)) + filters = self.build_filters(alias, filter_specs) + if filters is not None: + self._query = self._query.filter(filters) # PROJECTIONS ########################## @@ -601,7 +569,7 @@ def get_column(colname: str, alias: AliasedClass) -> InstrumentedAttribute: '{}'.format(colname, alias, '\n'.join(alias._sa_class_manager.mapper.c.keys())) # pylint: disable=protected-access ) from exc - def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> BooleanClauseList: + def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Optional[BooleanClauseList]: # pylint: disable=too-many-branches """Recurse through the filter specification and apply filter operations. :param alias: The alias of the ORM class the filter will be applied on @@ -612,17 +580,20 @@ def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Boo expressions: List[Any] = [] for path_spec, filter_operation_dict in filter_spec.items(): if path_spec in ('and', 'or', '~or', '~and', '!and', '!or'): - subexpressions = [ - self.build_filters(alias, sub_filter_spec) for sub_filter_spec in filter_operation_dict - ] - if path_spec == 'and': - expressions.append(and_(*subexpressions)) - elif path_spec == 'or': - expressions.append(or_(*subexpressions)) - elif path_spec in ('~and', '!and'): - expressions.append(not_(and_(*subexpressions))) - elif path_spec in ('~or', '!or'): - expressions.append(not_(or_(*subexpressions))) + subexpressions = [] + for sub_filter_spec in filter_operation_dict: + filters = self.build_filters(alias, sub_filter_spec) + if filters is not None: + subexpressions.append(filters) + if subexpressions: + if path_spec == 'and': + expressions.append(and_(*subexpressions)) + elif path_spec == 'or': + expressions.append(or_(*subexpressions)) + elif path_spec in ('~and', '!and'): + expressions.append(not_(and_(*subexpressions))) + elif path_spec in ('~or', '!or'): + expressions.append(not_(or_(*subexpressions))) else: column_name = path_spec.split('.')[0] @@ -650,7 +621,7 @@ def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Boo alias=alias ) ) - return and_(*expressions) + return and_(*expressions) if expressions else None def modify_expansions(self, alias: AliasedClass, expansions: List[str]) -> List[str]: """Modify names of projections if `**` was specified. diff --git a/aiida/orm/implementation/sqlalchemy/utils.py b/aiida/orm/implementation/sqlalchemy/utils.py index 6a1f9f654d..42607c31c4 100644 --- a/aiida/orm/implementation/sqlalchemy/utils.py +++ b/aiida/orm/implementation/sqlalchemy/utils.py @@ -146,7 +146,7 @@ def _in_transaction(): :return: boolean, True if currently in open transaction, False otherwise. 
""" - return get_scoped_session().transaction.nested + return get_scoped_session().in_nested_transaction() @contextlib.contextmanager diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 81036ece0c..707cac77dc 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,protected-access """Tests for the migration engine (Alembic) as well as for the AiiDA migrations for SQLAlchemy.""" from contextlib import contextmanager @@ -779,7 +779,7 @@ def setUpBeforeMigration(self): param_data = session.query(DbLog).filter(DbLog.objpk == param.id ).filter(DbLog.objname == 'something.else.' ).with_entities(*cols_to_project).one() - serialized_param_data = dumps_json([(dict(list(zip(param_data.keys(), param_data))))]) + serialized_param_data = dumps_json([param_data._asdict()]) # Getting the serialized logs for the unknown entity logs (as the export migration fuction # provides them) - this should coincide to the above serialized_unknown_exp_logs = log_migration.get_serialized_unknown_entity_logs(connection) @@ -792,7 +792,7 @@ def setUpBeforeMigration(self): leg_wf = session.query(DbLog).filter(DbLog.objpk == leg_workf.id).filter( DbLog.objname == 'aiida.workflows.user.topologicalworkflows.topo.TopologicalWorkflow' ).with_entities(*cols_to_project).one() - serialized_leg_wf_logs = dumps_json([(dict(list(zip(leg_wf.keys(), leg_wf))))]) + serialized_leg_wf_logs = dumps_json([leg_wf._asdict()]) # Getting the serialized logs for the legacy workflow logs (as the export migration function # provides them) - this should coincide to the above serialized_leg_wf_exp_logs = log_migration.get_serialized_legacy_workflow_logs(connection) @@ -803,9 +803,7 @@ def setUpBeforeMigration(self): # Getting the serialized logs that don't correspond to a DbNode record logs_no_node = session.query(DbLog).filter( DbLog.id.in_([log_5.id, log_6.id])).with_entities(*cols_to_project) - logs_no_node_list = list() - for log_no_node in logs_no_node: - logs_no_node_list.append((dict(list(zip(log_no_node.keys(), log_no_node))))) + logs_no_node_list = [log_no_node._asdict() for log_no_node in logs_no_node] serialized_logs_no_node = dumps_json(logs_no_node_list) # Getting the serialized logs that don't correspond to a node (as the export migration function diff --git a/tests/backends/aiida_sqlalchemy/test_schema.py b/tests/backends/aiida_sqlalchemy/test_schema.py index 1aa5341efc..bffa84dc7e 100644 --- a/tests/backends/aiida_sqlalchemy/test_schema.py +++ b/tests/backends/aiida_sqlalchemy/test_schema.py @@ -9,6 +9,10 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Test object relationships in the database.""" +import warnings + +from sqlalchemy import exc as sa_exc + from aiida.backends.testbase import AiidaTestCase from aiida.backends.sqlalchemy.models.user import DbUser from aiida.backends.sqlalchemy.models.node import DbNode @@ -111,9 +115,6 @@ def test_user_node_2(self): storing USER does NOT induce storage of the NODE Assert the correct storage of user and node.""" - import warnings - from sqlalchemy import exc as sa_exc - # Create user 
dbu1 = DbUser('tests2@schema', 'spam', 'eggs', 'monty') @@ -164,7 +165,10 @@ def test_user_node_3(self): # Add only first node and commit session.add(dbn_1) - session.commit() + with warnings.catch_warnings(): + # suppress known SAWarning that we have not added dbn_2 + warnings.simplefilter('ignore', category=sa_exc.SAWarning) + session.commit() # Check for which object a pk has been assigned, which means that # things have been at least flushed into the database @@ -200,7 +204,10 @@ def test_user_node_4(self): # Add only first node and commit session.add(dbn_1) - session.commit() + with warnings.catch_warnings(): + # suppress known SAWarning that we have not added the other nodes + warnings.simplefilter('ignore', category=sa_exc.SAWarning) + session.commit() # Check for which object a pk has been assigned, which means that # things have been at least flushed into the database diff --git a/tests/backends/aiida_sqlalchemy/test_session.py b/tests/backends/aiida_sqlalchemy/test_session.py index c868ea37ed..8707134cf6 100644 --- a/tests/backends/aiida_sqlalchemy/test_session.py +++ b/tests/backends/aiida_sqlalchemy/test_session.py @@ -164,7 +164,7 @@ def test_node_access_with_sessions(self): self.assertIsNot(master_session, custom_session) # Manually load the DbNode in a different session - dbnode_reloaded = custom_session.query(sa.models.node.DbNode).get(node.id) + dbnode_reloaded = custom_session.get(sa.models.node.DbNode, node.id) # Now, go through one by one changing the possible attributes (of the model) # and check that they're updated when the user reads them from the aiida node diff --git a/tests/backends/aiida_sqlalchemy/test_utils.py b/tests/backends/aiida_sqlalchemy/test_utils.py index 1235ae4e4b..398e1122b5 100644 --- a/tests/backends/aiida_sqlalchemy/test_utils.py +++ b/tests/backends/aiida_sqlalchemy/test_utils.py @@ -58,8 +58,8 @@ def database_exists(url): try: if engine.dialect.name == 'postgresql': - text = f"SELECT 1 FROM pg_database WHERE datname='{database}'" - return bool(engine.execute(text).scalar()) + text = sa.text(f"SELECT 1 FROM pg_database WHERE datname='{database}'") + return bool(engine.connect().execute(text).scalar()) raise Exception('Only PostgreSQL is supported.') finally: @@ -98,9 +98,9 @@ def create_database(url, encoding='utf8'): from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - text = f"CREATE DATABASE {quote(engine, database)} ENCODING '{encoding}'" - - engine.execute(text) + text = sa.text(f"CREATE DATABASE {quote(engine, database)} ENCODING '{encoding}'") + with engine.begin() as connection: + connection.execute(text) else: raise Exception('Only PostgreSQL with the psycopg2 driver is supported.') diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index c596062aa1..43159b020e 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -798,7 +798,7 @@ class TestQueryBuilderCornerCases: In this class corner cases of QueryBuilder are added. """ - def test_computer_json(self): # pylint: disable=no-self-use + def test_computer_json(self): """ In this test we check the correct behavior of QueryBuilder when retrieving the _metadata with no content.
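The test_utils.py hunks above follow the same SQLAlchemy 1.4 "future" rules as the migration scripts: engine-level execute() is gone, and textual SQL must be wrapped in text(). A minimal sketch of the two connection patterns, with an in-memory SQLite URL standing in for the PostgreSQL one used by the tests:

    from sqlalchemy import create_engine, text

    engine = create_engine('sqlite://', future=True)

    # Read path: connect() yields a plain connection, no implicit commit.
    with engine.connect() as connection:
        exists = bool(connection.execute(text('SELECT 1')).scalar())

    # Write/DDL path: begin() commits the block on success, rolls back on error.
    with engine.begin() as connection:
        connection.execute(text('CREATE TABLE demo (id INTEGER PRIMARY KEY)'))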
@@ -818,6 +818,14 @@ def test_computer_json(self): # pylint: disable=no-self-use qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc') qb.all() + def test_empty_filters(self): + """Test that an empty filter is correctly handled.""" + orm.Data().store() + qb = orm.QueryBuilder().append(orm.Data, filters={}) + assert qb.count() == 1 + qb = orm.QueryBuilder().append(orm.Data, filters={'or': [{}, {}]}) + assert qb.count() == 1 + @pytest.mark.usefixtures('clear_database_before_test') class TestAttributes:
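For reference, the reworked transaction() context manager in backend.py earlier in this diff relies on the same 1.4 session API that _in_transaction() now queries: Session.begin() opens the outermost transaction only when none is active, while begin_nested() emits a SAVEPOINT so a failing inner block rolls back without discarding outer work. A standalone sketch of that shape (illustrative engine and session, not the AiiDA backend itself):

    from contextlib import contextmanager

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine('sqlite://', future=True)

    @contextmanager
    def transaction(session):
        """Savepoint-based transaction that can be nested, as in backend.py."""
        if session.in_transaction():
            with session.begin_nested():  # already inside: only add a SAVEPOINT
                yield session
        else:
            with session.begin():  # outermost caller: open the real transaction
                with session.begin_nested():
                    yield session

    session = Session(engine, future=True)
    with transaction(session) as outer:
        with transaction(outer):  # nested call reuses the open transaction
            pass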