🔀 MERGE: SQLAlchemy v1.4 (v2 API) (#5122)
These commits follow https://docs.sqlalchemy.org/en/14/changelog/migration_20.html
(see also https://docs.sqlalchemy.org/en/14/errors.html):

- Added `SQLALCHEMY_WARN_20` environment variable
- Addressed all resulting warnings
- Added `future=True` flag for engine and session creation (v1 -> v2 API)

Each commit primarily addresses a single `SQLALCHEMY_WARN_20` warning.
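For context, SQLAlchemy 1.4 only emits its `RemovedIn20Warning` deprecation warnings when `SQLALCHEMY_WARN_20` is set before the library is first imported. A minimal sketch of how such warnings can be surfaced and escalated to errors (the exact invocation used for these commits is not part of this page):

import os

# Must be set before sqlalchemy is imported, since the flag is read at import time.
os.environ['SQLALCHEMY_WARN_20'] = '1'

import warnings

from sqlalchemy.exc import RemovedIn20Warning

# Turn each 2.0 deprecation warning into an error so it has to be fixed, mirroring
# the documented `SQLALCHEMY_WARN_20=1 python -W error::DeprecationWarning ...` recipe.
warnings.simplefilter('error', RemovedIn20Warning)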
chrisjsewell authored Sep 16, 2021
2 parents a295616 + f926619 commit b31e5f5
Showing 23 changed files with 139 additions and 150 deletions.
17 changes: 8 additions & 9 deletions aiida/backends/sqlalchemy/migrations/env.py
@@ -30,20 +30,19 @@ def run_migrations_online():
     from aiida.backends.sqlalchemy.models.base import Base
     config = context.config  # pylint: disable=no-member
 
-    connectable = config.attributes.get('connection', None)
+    connection = config.attributes.get('connection', None)
 
-    if connectable is None:
+    if connection is None:
         from aiida.common.exceptions import ConfigurationError
         raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.')
 
-    with connectable.connect() as connection:
-        context.configure(  # pylint: disable=no-member
-            connection=connection,
-            target_metadata=Base.metadata,
-            transaction_per_migration=True,
-        )
+    context.configure(  # pylint: disable=no-member
+        connection=connection,
+        target_metadata=Base.metadata,
+        transaction_per_migration=True,
+    )
 
-        context.run_migrations()  # pylint: disable=no-member
+    context.run_migrations()  # pylint: disable=no-member
 
 
 try:
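The `with connectable.connect()` block disappears above because, in the v2-style pattern, Alembic's `env.py` is handed an already-initialized `Connection` through `config.attributes` instead of an `Engine` to connect from. A minimal sketch of the calling side, following the Alembic "sharing a connection" recipe (the function name and URL handling are illustrative, not from this commit):

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine

def run_upgrade(script_location: str, db_url: str) -> None:
    """Run the online migrations, sharing one connection with env.py."""
    engine = create_engine(db_url, future=True)
    with engine.begin() as connection:
        config = Config()
        config.set_main_option('script_location', script_location)
        config.attributes['connection'] = connection  # read back by env.py above
        command.upgrade(config, 'head')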
@@ -7,7 +7,7 @@
 # For further information on the license, see the LICENSE.txt file #
 # For further information please visit http://www.aiida.net #
 ###########################################################################
-# pylint: disable=invalid-name,no-member,import-error,no-name-in-module
+# pylint: disable=invalid-name,no-member,import-error,no-name-in-module,protected-access
 """This migration cleans the log records from non-Node entity records.
 It removes from the DbLog table the legacy workflow records and records
 that correspond to an unknown entity and places them to corresponding files.
@@ -95,7 +95,7 @@ def get_serialized_legacy_workflow_logs(connection):
     )
     res = list()
     for row in query:
-        res.append(dict(list(zip(row.keys(), row))))
+        res.append(row._asdict())
     return dumps_json(res)
 
 
@@ -114,7 +114,7 @@ def get_serialized_unknown_entity_logs(connection):
     )
     res = list()
     for row in query:
-        res.append(dict(list(zip(row.keys(), row))))
+        res.append(row._asdict())
     return dumps_json(res)
 
 
@@ -133,7 +133,7 @@ def get_serialized_logs_with_no_nodes(connection):
     )
     res = list()
     for row in query:
-        res.append(dict(list(zip(row.keys(), row))))
+        res.append(row._asdict())
     return dumps_json(res)
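In SQLAlchemy 1.4 a result `Row` behaves like a named tuple rather than a mapping, so the 1.3 idiom `dict(zip(row.keys(), row))` gives way to `Row._asdict()` (hence the new `protected-access` pylint exemption). A minimal sketch against an in-memory SQLite database (the query is illustrative, not from this commit):

from sqlalchemy import create_engine, text

engine = create_engine('sqlite://', future=True)
with engine.connect() as connection:
    row = connection.execute(text("SELECT 1 AS id, 'hello' AS message")).one()
    print(row._asdict())  # {'id': 1, 'message': 'hello'}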
@@ -46,7 +46,7 @@ def upgrade():
                   column('attributes', JSONB))
 
     nodes = connection.execute(
-        select([DbNode.c.id, DbNode.c.uuid]).where(
+        select(DbNode.c.id, DbNode.c.uuid).where(
             DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()
 
     for pk, uuid in nodes:
@@ -64,7 +64,7 @@ def downgrade():
                   column('attributes', JSONB))
 
     nodes = connection.execute(
-        select([DbNode.c.id, DbNode.c.uuid]).where(
+        select(DbNode.c.id, DbNode.c.uuid).where(
             DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()
 
     for pk, _ in nodes:
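This and several of the following migration files show the same 2.0-API change: `select()` now takes its columns as positional arguments rather than a list. A minimal sketch using the same lightweight `table()`/`column()` constructs as the migrations (the table handle and filter are illustrative):

from sqlalchemy import column, func, select, table

DbNode = table('db_dbnode', column('id'), column('uuid'), column('type'))

# 1.3 (legacy): select([DbNode.c.id, DbNode.c.uuid])
stmt = select(DbNode.c.id, DbNode.c.uuid).where(DbNode.c.type.like('calculation.%'))

# The same applies to aggregates:
count_stmt = select(func.count()).select_from(DbNode)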
@@ -58,9 +58,9 @@ def export_workflow_data(connection):
     DbWorkflowData = table('db_dbworkflowdata')
     DbWorkflowStep = table('db_dbworkflowstep')
 
-    count_workflow = connection.execute(select([func.count()]).select_from(DbWorkflow)).scalar()
-    count_workflow_data = connection.execute(select([func.count()]).select_from(DbWorkflowData)).scalar()
-    count_workflow_step = connection.execute(select([func.count()]).select_from(DbWorkflowStep)).scalar()
+    count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar()
+    count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar()
+    count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar()
 
     # Nothing to do if all tables are empty
     if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0:
@@ -78,9 +78,9 @@ def export_workflow_data(connection):
     delete_on_close = configuration.PROFILE.is_test_profile
 
     data = {
-        'workflow': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflow))],
-        'workflow_data': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowData))],
-        'workflow_step': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowStep))],
+        'workflow': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflow))],
+        'workflow_data': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowData))],
+        'workflow_step': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowStep))],
     }
 
     with NamedTemporaryFile(
@@ -45,7 +45,7 @@ def upgrade():
     )
 
     profile = get_profile()
-    node_count = connection.execute(select([func.count()]).select_from(DbNode)).scalar()
+    node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar()
     missing_repo_folder = []
     shard_count = 256
@@ -41,7 +41,7 @@ def migrate_infer_calculation_entry_point(connection):
         column('process_type', String)
     )
 
-    query_set = connection.execute(select([DbNode.c.type]).where(DbNode.c.type.like('calculation.%'))).fetchall()
+    query_set = connection.execute(select(DbNode.c.type).where(DbNode.c.type.like('calculation.%'))).fetchall()
     type_strings = set(entry[0] for entry in query_set)
     mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings)
 
@@ -54,7 +54,7 @@ def migrate_infer_calculation_entry_point(connection):
         # All affected entries should be logged to file that the user can consult.
         if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string:
             query_set = connection.execute(
-                select([DbNode.c.uuid]).where(DbNode.c.type == op.inline_literal(type_string))
+                select(DbNode.c.uuid).where(DbNode.c.type == op.inline_literal(type_string))
             ).fetchall()
 
             uuids = [str(entry.uuid) for entry in query_set]
@@ -31,8 +31,8 @@ def upgrade():
     """Migrations for the upgrade."""
     op.drop_table('db_dbpath')
     conn = op.get_bind()
-    conn.execute('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink')
-    conn.execute('DROP FUNCTION IF EXISTS update_tc()')
+    conn.execute(sa.text('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink'))
+    conn.execute(sa.text('DROP FUNCTION IF EXISTS update_tc()'))
 
 
 def downgrade():
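Another recurring warning addressed here: under the 2.0 API, raw SQL strings are no longer implicitly coerced and must be wrapped in an explicit `text()` construct. A minimal sketch (the DDL statement is illustrative, not from this commit):

from sqlalchemy import create_engine, text

engine = create_engine('sqlite://', future=True)
with engine.begin() as connection:
    # 1.3 accepted connection.execute('CREATE TABLE ...') directly.
    connection.execute(text('CREATE TABLE IF NOT EXISTS demo (id INTEGER PRIMARY KEY)'))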
@@ -43,7 +43,7 @@ def upgrade():
                   column('attributes', JSONB))
 
     nodes = connection.execute(
-        select([DbNode.c.id, DbNode.c.uuid]).where(
+        select(DbNode.c.id, DbNode.c.uuid).where(
            DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()
 
     for pk, uuid in nodes:
@@ -61,11 +61,11 @@ def downgrade():
                   column('attributes', JSONB))
 
     nodes = connection.execute(
-        select([DbNode.c.id, DbNode.c.uuid]).where(
+        select(DbNode.c.id, DbNode.c.uuid).where(
            DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()
 
     for pk, uuid in nodes:
-        attributes = connection.execute(select([DbNode.c.attributes]).where(DbNode.c.id == pk)).fetchone()
+        attributes = connection.execute(select(DbNode.c.attributes).where(DbNode.c.id == pk)).fetchone()
         symbols = numpy.array(attributes['symbols'])
         utils.store_numpy_array_in_repository(uuid, 'symbols', symbols)
         key = op.inline_literal('{"array|symbols"}')
@@ -33,7 +33,7 @@ def set_new_uuid(connection):
     from aiida.common.utils import get_new_uuid
 
     # Exit if there are no rows - e.g. initial setup
-    id_query = connection.execute('SELECT db_dblog.id FROM db_dblog')
+    id_query = connection.execute(sa.text('SELECT db_dblog.id FROM db_dblog'))
     if id_query.rowcount == 0:
         return
 
@@ -52,7 +52,7 @@ def set_new_uuid(connection):
         UPDATE db_dblog as t SET
             uuid = uuid(c.uuid)
         from (values {key_values}) as c(id, uuid) where c.id = t.id"""
-    connection.execute(update_stm)
+    connection.execute(sa.text(update_stm))
 
 
 def upgrade():
2 changes: 1 addition & 1 deletion aiida/backends/sqlalchemy/models/base.py
@@ -11,7 +11,7 @@
 """Base SQLAlchemy models."""
 
 from sqlalchemy import orm
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm.exc import UnmappedClassError
 
 import aiida.backends.sqlalchemy
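`declarative_base()` itself is unchanged; only its import location moves, since `sqlalchemy.ext.declarative` is deprecated in 1.4 in favour of `sqlalchemy.orm`. A minimal sketch (the model is illustrative, not from this commit):

from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base  # 1.3: from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class DbDemo(Base):
    __tablename__ = 'db_demo'
    id = Column(Integer, primary_key=True)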
10 changes: 7 additions & 3 deletions aiida/backends/sqlalchemy/utils.py
@@ -60,6 +60,8 @@ def install_tc(session):
     """
     Install the transitive closure table with SqlAlchemy.
     """
+    from sqlalchemy import text
+
     links_table_name = 'db_dblink'
     links_table_input_field = 'input_id'
     links_table_output_field = 'output_id'
@@ -68,9 +70,11 @@ def install_tc(session):
     closure_table_parent_field = 'parent_id'
     closure_table_child_field = 'child_id'
 
     session.execute(
-        get_pg_tc(
-            links_table_name, links_table_input_field, links_table_output_field, closure_table_name,
-            closure_table_parent_field, closure_table_child_field
+        text(
+            get_pg_tc(
+                links_table_name, links_table_input_field, links_table_output_field, closure_table_name,
+                closure_table_parent_field, closure_table_child_field
+            )
         )
     )
4 changes: 2 additions & 2 deletions aiida/backends/utils.py
@@ -38,14 +38,14 @@ def create_sqlalchemy_engine(profile, **kwargs):
         name=profile.database_name
     )
     return create_engine(
-        engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=False, encoding='utf-8', **kwargs
+        engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=True, encoding='utf-8', **kwargs
     )
 
 
 def create_scoped_session_factory(engine, **kwargs):
     """Create scoped SQLAlchemy session factory"""
     from sqlalchemy.orm import scoped_session, sessionmaker
-    return scoped_session(sessionmaker(bind=engine, **kwargs))
+    return scoped_session(sessionmaker(bind=engine, future=True, **kwargs))
 
 
 def delete_nodes_and_connections(pks):
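`future=True` opts the engine and its sessions into 2.0-style behaviour (explicit `text()` coercion, named-tuple rows, the new transaction model) while still running on 1.4. A minimal sketch of the resulting pattern, with a placeholder database URL:

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine('postgresql://localhost/example', future=True)
Session = scoped_session(sessionmaker(bind=engine, future=True))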
22 changes: 9 additions & 13 deletions aiida/orm/implementation/sqlalchemy/backend.py
@@ -83,18 +83,13 @@ def transaction(self):
         entering. Transactions can be nested.
         """
         session = self.get_session()
-        nested = session.transaction.nested
-        try:
-            session.begin_nested()
-            yield session
-            session.commit()
-        except Exception:
-            session.rollback()
-            raise
-        finally:
-            if not nested:
-                # Make sure to commit the outermost session
-                session.commit()
+        if session.in_transaction():
+            with session.begin_nested():
+                yield session
+        else:
+            with session.begin():
+                with session.begin_nested():
+                    yield session
 
     @staticmethod
     def get_session():
@@ -131,10 +126,11 @@ def execute_raw(self, query):
         :param query: a string containing a raw SQL statement
         :return: the result of the query
         """
+        from sqlalchemy import text
         from sqlalchemy.exc import ResourceClosedError  # pylint: disable=import-error,no-name-in-module
 
         with self.transaction() as session:
-            queryset = session.execute(query)
+            queryset = session.execute(text(query))
 
         try:
             results = queryset.fetchall()
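The rewritten `transaction()` above leans on the 1.4 session API: `Session.in_transaction()` reports whether a transaction has already begun, and `begin()`/`begin_nested()` act as context managers that commit on success and roll back on error, replacing the manual try/except/finally bookkeeping. A small sketch of the same pattern in isolation (the standalone function is illustrative):

from contextlib import contextmanager

@contextmanager
def transaction(session):
    if session.in_transaction():
        with session.begin_nested():  # SAVEPOINT inside the open transaction
            yield session
    else:
        with session.begin():  # open the outermost transaction first
            with session.begin_nested():
                yield session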
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/computers.py
@@ -137,7 +137,7 @@ def list_names():
     def delete(self, pk):
         try:
             session = get_scoped_session()
-            session.query(DbComputer).get(pk).delete()
+            session.get(DbComputer, pk).delete()
             session.commit()
         except SQLAlchemyError as exc:
             raise exceptions.InvalidOperation(
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/groups.py
@@ -367,5 +367,5 @@ def query(
     def delete(self, id):  # pylint: disable=redefined-builtin
         session = sa.get_scoped_session()
 
-        session.query(DbGroup).get(id).delete()
+        session.get(DbGroup, id).delete()
         session.commit()
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class SqlaJoiner:
"""A class containing the logic for SQLAlchemy entities joining entities."""

def __init__(
self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], BooleanClauseList]
self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType],
Optional[BooleanClauseList]]
):
"""Initialise the class"""
self._entities = entity_mapper
Expand Down Expand Up @@ -185,7 +186,13 @@ def _join_descendants_recursive(
link1 = aliased(self._entities.Link)
link2 = aliased(self._entities.Link)
node1 = aliased(self._entities.Node)

link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links
in_recursive_filters = self._build_filters(node1, filter_dict)
if in_recursive_filters is None:
filters = link_filters
else:
filters = and_(in_recursive_filters, link_filters)

selection_walk_list = [
link1.input_id.label('ancestor_id'),
Expand All @@ -195,12 +202,8 @@ def _join_descendants_recursive(
if expand_path:
selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path'))

walk = select(selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where(
and_(
in_recursive_filters, # I apply filters for speed here
link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links
)
).cte(recursive=True)
walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)
).where(filters).cte(recursive=True)

aliased_walk = aliased(walk)

Expand All @@ -214,13 +217,12 @@ def _join_descendants_recursive(

descendants_recursive = aliased(
aliased_walk.union_all(
select(selection_union_list).select_from(
join(
aliased_walk,
link2,
link2.input_id == aliased_walk.c.descendant_id,
)
).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
select(*selection_union_list
).select_from(join(
aliased_walk,
link2,
link2.input_id == aliased_walk.c.descendant_id,
)).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
)
) # .alias()

Expand Down Expand Up @@ -249,7 +251,13 @@ def _join_ancestors_recursive(
link1 = aliased(self._entities.Link)
link2 = aliased(self._entities.Link)
node1 = aliased(self._entities.Node)

link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links
in_recursive_filters = self._build_filters(node1, filter_dict)
if in_recursive_filters is None:
filters = link_filters
else:
filters = and_(in_recursive_filters, link_filters)

selection_walk_list = [
link1.input_id.label('ancestor_id'),
Expand All @@ -259,9 +267,8 @@ def _join_ancestors_recursive(
if expand_path:
selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path'))

walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where(
and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
).cte(recursive=True)
walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)
).where(filters).cte(recursive=True)

aliased_walk = aliased(walk)

Expand All @@ -275,13 +282,12 @@ def _join_ancestors_recursive(

ancestors_recursive = aliased(
aliased_walk.union_all(
select(selection_union_list).select_from(
join(
aliased_walk,
link2,
link2.output_id == aliased_walk.c.ancestor_id,
)
).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
select(*selection_union_list
).select_from(join(
aliased_walk,
link2,
link2.output_id == aliased_walk.c.ancestor_id,
)).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
# I can't follow RETURN or CALL links
)
)
Expand Down
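The querybuilder changes combine two 2.0-API updates: `select()` takes the walk list positionally (`select(*selection_walk_list)`), and `_build_filters` may now return `None` (hence `Optional[BooleanClauseList]`), so the link-type filter is combined with `and_()` only when node filters exist. A stripped-down sketch of the recursive-CTE shape used here, with an illustrative lightweight table:

from sqlalchemy import column, join, select, table

link = table('db_dblink', column('input_id'), column('output_id'), column('type'))

# Seed of the walk: every direct link is an (ancestor, descendant) pair.
walk = select(
    link.c.input_id.label('ancestor_id'),
    link.c.output_id.label('descendant_id'),
).cte(recursive=True)

# Recursive step: extend each walk by one more outgoing link.
descendants = walk.union_all(
    select(walk.c.ancestor_id, link.c.output_id.label('descendant_id')
           ).select_from(join(walk, link, link.c.input_id == walk.c.descendant_id))
)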