Commit
implement bulk_update's extend json option
rabbull committed Dec 10, 2024
1 parent ec52f4e commit 86323d2
Showing 6 changed files.
4 changes: 3 additions & 1 deletion src/aiida/orm/implementation/storage_backend.py
@@ -238,12 +238,14 @@ def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]:
"""

@abc.abstractmethod
- def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None:
+ def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], extend_json: bool = False) -> None:
"""Update a list of entities in the database, directly with a backend transaction.
:param entity_type: The type of the entity
:param rows: A list of dictionaries, each containing the fields of the backend model to update
    and the `id` field (a.k.a. the primary key)
:param extend_json: Whether updates to JSON fields are merged into the existing JSON object
    instead of overwriting the entire object
:raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
"""
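For orientation, a minimal usage sketch of the new flag (illustrative node and values, not part of the commit; `get_manager` and `EntityTypes` import paths as in aiida-core, and a loaded profile is assumed):

from aiida import load_profile, orm
from aiida.manage import get_manager
from aiida.orm.entities import EntityTypes

load_profile()
backend = get_manager().get_profile_storage()

node = orm.Dict({'a': 1, 'b': {'x': 1}}).store()

# With extend_json=True the given JSON is merged into the stored value:
# 'a' survives, 'b' is deep-merged, 'c' is added.
backend.bulk_update(
    EntityTypes.NODE,
    [{'id': node.pk, 'attributes': {'b': {'y': 2}, 'c': 3}}],
    extend_json=True,
)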
43 changes: 36 additions & 7 deletions src/aiida/storage/psql_dos/backend.py
@@ -11,12 +11,15 @@
import functools
import gc
import pathlib
from collections import defaultdict
from contextlib import contextmanager, nullcontext
- from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union
+ from typing import TYPE_CHECKING, Iterator, Optional, Sequence, Set, Union
import json

from disk_objectstore import Container, backup_utils
from pydantic import BaseModel, Field
- from sqlalchemy import column, insert, update
+ from sqlalchemy import case, cast, column, func, insert, update
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common import exceptions
@@ -314,7 +317,7 @@ def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
return mapper, keys

- def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
+ def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
mapper, keys = self._get_mapper_from_entity(entity_type, False)
if not rows:
return []
@@ -337,18 +340,44 @@ def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults
result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
return [row.id for row in result]

- def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
+ def bulk_update(self, entity_type: EntityTypes, rows: list[dict], extend_json: bool = False) -> None:
mapper, keys = self._get_mapper_from_entity(entity_type, True)
if not rows:
return None

to_json = functools.partial(cast, type_=JSONB)
if self.get_session().bind.dialect.name == 'sqlite':
    # TODO: a dirty workaround: the SQLite DOS storage does not yet have a dedicated
    # backend, and SQLite has no JSONB type, so the cast has to be implemented separately.
    to_json = json.dumps

cases = defaultdict(list)
id_list = []
for row in rows:
    if 'id' not in row:
        raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
    if not keys.issuperset(row):
        raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
    when = mapper.c.id == row['id']
    id_list.append(row['id'])

    for key, value in row.items():
        if key == 'id':
            continue
        update_value = value
        if extend_json and key in ['extra', 'attributes']:
            # Merge the patch into the stored JSON via the custom ``json_patch`` function
            # instead of overwriting the whole object.
            update_value = func.json_patch(mapper.c[key], to_json(value))
        cases[key].append((when, update_value))

session = self.get_session()
with nullcontext() if self.in_transaction else self.transaction():
- session.execute(update(mapper), rows)
+ values = {k: case(*v, else_=mapper.c[k]) for k, v in cases.items()}
+ stmt = update(mapper).where(mapper.c.id.in_(id_list)).values(**values)
+ session.execute(stmt)

def delete(self, delete_database_user: bool = False) -> None:
"""Delete the storage and all the data.
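As a sanity check of the per-column CASE construction used above, a standalone SQLAlchemy sketch (hypothetical table and values, not part of the commit):

from sqlalchemy import Integer, String, case, column, table, update

users = table('users', column('id', Integer), column('email', String))

# One (condition, value) pair per row; rows not matched keep their current value.
whens = [
    (users.c.id == 1, 'alice@example.org'),
    (users.c.id == 2, 'bob@example.org'),
]
stmt = (
    update(users)
    .where(users.c.id.in_([1, 2]))
    .values(email=case(*whens, else_=users.c.email))
)
print(stmt)  # UPDATE users SET email=CASE WHEN ... ELSE users.email END WHERE users.id IN (...)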
46 changes: 45 additions & 1 deletion src/aiida/storage/psql_dos/utils.py
@@ -10,6 +10,7 @@

import json
from typing import TypedDict
from sqlalchemy import event


class PsqlConfig(TypedDict, total=False):
@@ -25,6 +26,47 @@ class PsqlConfig(TypedDict, total=False):
"""keyword argument that will be passed on to the SQLAlchemy engine."""


# Adapted from https://stackoverflow.com/a/79133234/9184828
JSONB_PATCH_FUNCTION = """
CREATE OR REPLACE FUNCTION json_patch (
target jsonb, -- target JSON value
patch jsonb -- patch JSON value
)
RETURNS jsonb
LANGUAGE plpgsql
IMMUTABLE AS $$
BEGIN
-- If the patch is not a JSON object, return the patch as the result (base case)
IF patch isnull or jsonb_typeof(patch) != 'object' THEN
RETURN patch;
END IF;
-- If the target is not an object, set it to an empty object
IF target isnull or jsonb_typeof(target) != 'object' THEN
target := '{}';
END IF;
RETURN coalesce(
jsonb_object_agg(
coalesce(targetKey, patchKey), -- there will be either one or both keys equal
CASE
WHEN patchKey isnull THEN targetValue -- key missing in patch - retain target value
ELSE json_patch(targetValue, patchValue)
END
),
'{}'::jsonb -- if SELECT will return no keys (empty table), then jsonb_object_agg will return NULL, need to return {} in that case
)
FROM jsonb_each(target) temp1(targetKey, targetValue)
FULL JOIN jsonb_each(patch) temp2(patchKey, patchValue)
ON targetKey = patchKey
WHERE jsonb_typeof(patchValue) != 'null' OR patchValue isnull; -- remove keys which are set to null in patch object
END;
$$;
""".strip()

def register_jsonb_patch_function(dbapi_connection, *args, **kwargs):
    """Install the ``json_patch`` function on each new database connection."""
    with dbapi_connection.cursor() as cursor:
        cursor.execute(JSONB_PATCH_FUNCTION)
    # Commit so the function persists; the DBAPI opens an implicit transaction for the DDL.
    dbapi_connection.commit()

def create_sqlalchemy_engine(config: PsqlConfig):
"""Create SQLAlchemy engine (to be used for QueryBuilder queries)
@@ -50,12 +92,14 @@ def create_sqlalchemy_engine(config: PsqlConfig):
port=config['database_port'],
name=config['database_name'],
)
- return create_engine(
+ engine = create_engine(
engine_url,
json_serializer=json.dumps,
json_deserializer=json.loads,
**config.get('engine_kwargs', {}),
)
event.listen(engine, 'connect', register_jsonb_patch_function)
return engine


def create_scoped_session_factory(engine, **kwargs):
30 changes: 25 additions & 5 deletions src/aiida/storage/sqlite_temp/backend.py
@@ -14,14 +14,16 @@
import hashlib
import json
import os
import shutil
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from tempfile import mkdtemp
from typing import Any, BinaryIO, Iterator, Sequence

from pydantic import BaseModel, Field
- from sqlalchemy import column, insert, update
+ from sqlalchemy import column, func, insert, update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import case

from aiida.common.exceptions import ClosedStorage, IntegrityError
from aiida.manage.configuration import Profile
@@ -268,18 +270,36 @@ def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults
result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
return [row.id for row in result]

- def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
+ def bulk_update(self, entity_type: EntityTypes, rows: list[dict], extend_json: bool = False) -> None:
mapper, keys = self._get_mapper_from_entity(entity_type, True)
if not rows:
return None

cases = defaultdict(list)
id_list = []
for row in rows:
    if 'id' not in row:
        raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
    if not keys.issuperset(row):
        raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
    when = mapper.c.id == row['id']
    id_list.append(row['id'])

    for key, value in row.items():
        if key == 'id':
            continue
        update_value = value
        if extend_json and key in ['extra', 'attributes']:
            # SQLite ships ``json_patch`` natively (JSON1 extension); the bound patch
            # must be serialised to a JSON string.
            update_value = func.json_patch(mapper.c[key], json.dumps(value))
        cases[key].append((when, update_value))

session = self.get_session()
with nullcontext() if self.in_transaction else self.transaction():
- session.execute(update(mapper), rows)
+ values = {k: case(*v, else_=mapper.c[k]) for k, v in cases.items()}
+ stmt = update(mapper).where(mapper.c.id.in_(id_list)).values(**values)
+ session.execute(stmt)

def delete(self) -> None:
"""Delete the storage and all the data."""
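Unlike PostgreSQL, SQLite ships json_patch natively in its JSON1 extension, which is why this backend registers no custom function. A quick standalone check (illustrative, assuming a Python build with JSON1 enabled, the default in modern releases):

import sqlite3

conn = sqlite3.connect(':memory:')
# SQLite's built-in json_patch implements RFC 7386 merge-patch semantics.
row = conn.execute(
    """SELECT json_patch('{"a": 1, "b": {"x": 1}}', '{"b": {"y": 2}}')"""
).fetchone()
print(row[0])  # -> {"a":1,"b":{"x":1,"y":2}}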
2 changes: 1 addition & 1 deletion src/aiida/storage/sqlite_zip/backend.py
@@ -291,7 +291,7 @@ def in_transaction(self) -> bool:
def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
raise ReadOnlyError()

- def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
+ def bulk_update(self, entity_type: EntityTypes, rows: list[dict], extend_json: bool = False) -> None:
raise ReadOnlyError()

def delete(self) -> None:
67 changes: 65 additions & 2 deletions tests/orm/implementation/test_backend.py
@@ -111,8 +112,8 @@ def test_bulk_update(self):
prefix = uuid.uuid4().hex
users = [orm.User(f'{prefix}-{i}').store() for i in range(3)]
# should raise if the 'id' field is not present
with pytest.raises(exceptions.IntegrityError, match="'id' field not given"):
self.backend.bulk_update(EntityTypes.USER, [{'email': 'other'}])
# should raise if a non-existent field is present
with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'):
self.backend.bulk_update(EntityTypes.USER, [{'id': users[0].pk, 'x': 'other'}])
@@ -123,6 +124,68 @@ def test_bulk_update(self):
assert users[1].email == 'other1'
assert users[2].email == f'{prefix}-2'

def test_bulk_update_extend_json(self):
    """Test ``bulk_update`` with ``extend_json=True``."""
prefix = uuid.uuid4().hex
nodes = [
orm.Dict(
{
'key-string': f'{prefix}-{index}',
'key-integer': index,
'key-null': None,
'key-object': {'k1': 'v1', 'k2': 2},
'key-array': [11, 45, 14],
}
).store()
for index in range(5)
]
self.backend.bulk_update(
EntityTypes.NODE,
[
{
'id': nodes[0].pk,
'attributes': {
'key-new': 'foobar',
},
},
{
'id': nodes[1].pk,
'attributes': {
'key-string': ['change type'],
'key-array': [1919, 810],
},
},
{
'id': nodes[2].pk,
'attributes': {
'key-integer': -1,
'key-object': {'k2': 114514},
},
},
],
extend_json=True,
)

# new attribute is added
assert nodes[0].get('key-new') == 'foobar'
# old attributes are kept
assert nodes[0].get('key-string') == f'{prefix}-0'
assert nodes[0].get('key-null') is None
assert nodes[0].get('key-integer') == 0
assert nodes[0].get('key-object') == {'k1': 'v1', 'k2': 2}
assert len(nodes[0].get('key-array')) == 3
assert all(x == y for x, y in zip(nodes[0].get('key-array'), [11, 45, 14]))
# change type
assert isinstance(nodes[1].get('key-string'), list)
assert len(nodes[1].get('key-string')) == 1
assert nodes[1].get('key-string')[0] == 'change type'
# overwrite array
assert len(nodes[1].get('key-array')) == 2
assert all(x == y for x, y in zip(nodes[1].get('key-array'), [1919, 810]))
# overwrite integer
assert nodes[2].get('key-integer') == -1
# merge object
assert nodes[2].get('key-object') == {'k1': 'v1', 'k2': 114514}

def test_bulk_update_in_transaction(self):
"""Test that bulk update in a cancelled transaction is not committed."""
prefix = uuid.uuid4().hex
