Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reset method to ProgressReporterAbstract #4522

Merged
merged 6 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions aiida/common/progress_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,25 @@ def __init__(self, *, total: int, desc: Optional[str] = None, **kwargs: Any):
:param desc: A description of the process

"""
self.total = total
self.desc = desc
self.increment = 0
self._total = total
self._desc = desc
self._increment: int = 0

@property
def total(self) -> int:
    """Return the total number of iterations expected.

    The value is read from the private ``_total`` attribute.
    """
    return self._total

@property
def desc(self) -> Optional[str]:
    """Return the description of the process (may be ``None``)."""
    description = self._desc
    return description

@property
def n(self) -> int:  # pylint: disable=invalid-name
    """Return the current iteration.

    The attribute has to be named ``n`` to stay compatible with tqdm.
    """
    current = self._increment
    return current

def __enter__(self) -> 'ProgressReporterAbstract':
"""Enter the contextmanager."""
Expand All @@ -71,15 +87,25 @@ def set_description_str(self, text: Optional[str] = None, refresh: bool = True):
:param refresh: Force refresh of the progress reporter

"""
self.desc = text
self._desc = text

def update(self, n: int = 1):  # pylint: disable=invalid-name
    """Update the progress counter.

    :param n: Increment to add to the internal counter of iterations

    """
    # Only the private counter is advanced; the stale public ``increment``
    # attribute from before the refactoring is no longer touched.
    self._increment += n

def reset(self, total: Optional[int] = None):
    """Set the current iteration count back to 0.

    :param total: New number of expected iterations; if ``None`` the current total is kept.

    """
    if total is not None:
        self._total = total
    self._increment = 0


class ProgressReporterNull(ProgressReporterAbstract):
Expand Down
127 changes: 64 additions & 63 deletions aiida/tools/importexport/archive/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
from pathlib import Path
from types import TracebackType
from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type
from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type
import zipfile
import tarfile

Expand All @@ -22,7 +22,6 @@
from aiida.common.log import AIIDA_LOGGER
from aiida.common.exceptions import InvalidOperation
from aiida.common.folders import Folder, SandboxFolder
from aiida.common.progress_reporter import get_progress_reporter
from aiida.tools.importexport.common.config import EXPORT_VERSION, ExportFileFormat, NODES_EXPORT_SUBFOLDER
from aiida.tools.importexport.common.exceptions import (CorruptArchive, IncompatibleArchiveVersionError)
from aiida.tools.importexport.archive.common import ArchiveMetadata
Expand All @@ -37,6 +36,7 @@
'ReaderJsonTar',
'ReaderJsonZip',
'get_reader',
'null_callback',
)

ARCHIVE_READER_LOGGER = AIIDA_LOGGER.getChild('archive.reader')
Expand All @@ -58,6 +58,10 @@ def get_reader(file_format: str) -> Type['ArchiveReaderAbstract']:
return cast(Type[ArchiveReaderAbstract], readers[file_format])


def null_callback(action: str, value: Any):  # pylint: disable=unused-argument
    """Accept a callback ``action`` and ``value`` and do nothing with them."""
    return None


class ArchiveReaderAbstract(ABC):
"""An abstract interface for AiiDA archive readers.

Expand Down Expand Up @@ -184,15 +188,21 @@ def iter_link_data(self) -> Iterator[dict]:
"""Iterate over links: {'input': <UUID>, 'output': <UUID>, 'label': <LABEL>, 'type': <TYPE>}"""

@abstractmethod
def iter_node_repos(
    self,
    uuids: Iterable[str],
    callback: Callable[[str, Any], None] = null_callback,
) -> Iterator[Folder]:
    """Yield temporary folders containing the contents of the repository for each node.

    :param uuids: UUIDs of the nodes over whose repository folders to iterate
    :param callback: a callback to report on the process, ``callback(action, value)``,
        with the following callback signatures:

        - ``callback('init', {'total': <int>, 'description': <str>})``,
            to signal the start of a process, its total iterations and description
        - ``callback('update', <int>)``,
            to signal an update to the process and the number of iterations to progress

    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: If the repository does not exist.
    """
Expand All @@ -204,7 +214,7 @@ def node_repository(self, uuid: str) -> Folder:

:raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: If the repository does not exist.
"""
return next(self.iter_node_repos([uuid], progress=False))
return next(self.iter_node_repos([uuid]))


class ReaderJsonBase(ArchiveReaderAbstract):
Expand Down Expand Up @@ -259,11 +269,17 @@ def _get_data(self):
"""Retrieve the data JSON."""
raise NotImplementedError()

def _extract(self, *, path_prefix: str, progress: bool):
def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None]):
"""Extract repository data to a temporary folder.

:param path_prefix: Only extract paths starting with this prefix.
:param progress: Whether to report progress of the extraction
:param callback: a callback to report on the process, ``callback(action, value)``,
with the following callback signatures:

- ``callback('init', {'total': <int>, 'description': <str>})``,
to signal the start of a process, its total iterations and description
- ``callback('update', <int>)``,
to signal an update to the process and the number of iterations to progress

:raises TypeError: if parameter types are not respected
"""
Expand Down Expand Up @@ -343,10 +359,11 @@ def iter_link_data(self) -> Iterator[dict]:
for value in self._get_data()['links_uuid']:
yield value

def iter_node_repos(self,
uuids: Iterable[str],
progress: bool = True,
description='Iterating node repositories') -> Iterator[Folder]:
def iter_node_repos(
self,
uuids: Iterable[str],
callback: Callable[[str, Any], None] = null_callback,
) -> Iterator[Folder]:
path_prefixes = [os.path.join(self.REPO_FOLDER, export_shard_uuid(uuid)) for uuid in uuids]

if not path_prefixes:
Expand All @@ -357,26 +374,17 @@ def iter_node_repos(self,
# unarchive the common folder if it does not exist
common_prefix = os.path.commonprefix(path_prefixes)
if not self._sandbox.get_subfolder(common_prefix).exists():
self._extract(path_prefix=common_prefix, progress=progress)

if progress:
with get_progress_reporter()(total=len(path_prefixes), desc=description) as report:
for uuid, path_prefix in zip(uuids, path_prefixes):
report.update()
subfolder = self._sandbox.get_subfolder(path_prefix)
if not subfolder.exists():
raise CorruptArchive(
f'Unable to find the repository folder for Node with UUID={uuid} in the archive'
)
yield subfolder
else:
for uuid, path_prefix in zip(uuids, path_prefixes):
subfolder = self._sandbox.get_subfolder(path_prefix)
if not subfolder.exists():
raise CorruptArchive(
f'Unable to find the repository folder for Node with UUID={uuid} in the exported file'
)
yield subfolder
self._extract(path_prefix=common_prefix, callback=callback)

callback('init', {'total': len(path_prefixes), 'description': 'Iterating node repositories'})
for uuid, path_prefix in zip(uuids, path_prefixes):
callback('update', 1)
subfolder = self._sandbox.get_subfolder(path_prefix)
if not subfolder.exists():
raise CorruptArchive(
f'Unable to find the repository folder for Node with UUID={uuid} in the exported file'
)
yield subfolder


class ReaderJsonZip(ReaderJsonBase):
Expand Down Expand Up @@ -410,20 +418,16 @@ def _get_data(self):
raise CorruptArchive(f'required file {self.FILENAME_DATA} is not included')
return self._data

def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None] = null_callback):
    """Extract the zip members whose name starts with ``path_prefix`` into the sandbox folder.

    :param path_prefix: Only extract paths starting with this prefix.
    :param callback: a callback to report on the process, ``callback(action, value)``;
        it is called once with ``('init', {'total': <int>, 'description': <str>})``
        and then with ``('update', 1)`` for every member that is extracted

    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the file is not a readable zip file
    """
    self.assert_within_context()
    assert self._sandbox is not None  # required by mypy
    try:
        with zipfile.ZipFile(self.filename, 'r', allowZip64=True) as handle:
            members = [m for m in handle.namelist() if m.startswith(path_prefix)]
            callback('init', {'total': len(members), 'description': 'Extracting repository files'})
            for membername in members:
                callback('update', 1)
                handle.extract(path=self._sandbox.abspath, member=membername)
    except zipfile.BadZipfile as error:
        # chain the original error so the traceback keeps the low-level cause
        raise CorruptArchive(f'The input file cannot be read: {error}') from error

Expand Down Expand Up @@ -457,29 +461,26 @@ def _get_data(self):
raise CorruptArchive(f'required file `{self.FILENAME_DATA}` is not included')
return self._data

def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None] = null_callback):
    """Extract the tar members whose name starts with ``path_prefix`` into the sandbox folder.

    Device files, symbolic links and hard links are never extracted; a warning is logged
    for each of them instead.

    :param path_prefix: Only extract paths starting with this prefix.
    :param callback: a callback to report on the process, ``callback(action, value)``;
        it is called once with ``('init', {'total': <int>, 'description': <str>})``
        and then with ``('update', 1)`` for every member that is considered

    :raises TypeError: if the file cannot be read as a tar archive
    """
    self.assert_within_context()
    assert self._sandbox is not None  # required by mypy
    try:
        with tarfile.open(self.filename, 'r:*', format=tarfile.PAX_FORMAT) as handle:
            members = [m for m in handle.getmembers() if m.name.startswith(path_prefix)]
            callback('init', {'total': len(members), 'description': 'Extracting repository files'})
            for member in members:
                callback('update', 1)
                if member.isdev():
                    # safety: skip if character device, block device or FIFO
                    msg = f'WARNING, device found inside the import file: {member.name}'
                    ARCHIVE_READER_LOGGER.warning(msg)
                    continue
                if member.issym() or member.islnk():
                    # safety: although dereference=True set in export, so this should not occur
                    msg = f'WARNING, symlink found inside the import file: {member.name}'
                    ARCHIVE_READER_LOGGER.warning(msg)
                    continue
                handle.extract(path=self._sandbox.abspath, member=member.name)
    except tarfile.ReadError as error:
        # tarfile signals an unreadable archive with ReadError; the previous
        # ``zipfile.BadZipfile`` handler could never be triggered by tarfile
        raise TypeError('The input file format is not valid (not a tar file)') from error

Expand Down Expand Up @@ -507,7 +508,7 @@ def _get_data(self):
self._data = json.loads(path.read_text(encoding='utf8'))
return self._data

def _extract(self, *, path_prefix: str, progress: bool):
def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None] = null_callback):
# pylint: disable=unused-argument
self.assert_within_context()
assert self._sandbox is not None # required by mypy
Expand Down
34 changes: 33 additions & 1 deletion aiida/tools/importexport/dbimport/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
###########################################################################
"""Common import functions for both database backend"""
import copy
from typing import Dict, Optional
from typing import Dict, List, Optional

from aiida.common import timezone
from aiida.common.folders import RepositoryFolder
from aiida.common.progress_reporter import get_progress_reporter
from aiida.orm import Group, ImportGroup, Node, QueryBuilder
from aiida.orm.utils._repository import Repository
from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract
from aiida.tools.importexport.common import exceptions
from aiida.tools.importexport.common.config import NODE_ENTITY_NAME
from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER
Expand All @@ -22,6 +25,35 @@
MAX_GROUPS = 100


def _copy_node_repositories(*, uuids_to_create: List[str], reader: ArchiveReaderAbstract):
    """Copy repositories of new nodes from the archive to the AiiDA profile.

    A single progress reporter is opened here and driven through the callback handed to
    the reader, which signals ``'init'`` and ``'update'`` actions.

    :param uuids_to_create: the node UUIDs to copy
    :param reader: the archive reader

    """
    if not uuids_to_create:
        return
    IMPORT_LOGGER.debug('CREATING NEW NODE REPOSITORIES...')
    # ``total=1`` is only a placeholder: the real total is set on the first 'init' callback
    with get_progress_reporter()(total=1) as progress:

        # the callback will handle manipulation of the progress reporter
        def _callback(action, value):
            if action == 'init':
                progress.reset(value['total'])
                progress.set_description_str(value['description'])
            elif action == 'update':
                progress.update(value)

        for import_entry_uuid, subfolder in zip(
            uuids_to_create, reader.iter_node_repos(uuids_to_create, callback=_callback)
        ):
            destdir = RepositoryFolder(section=Repository._section_name, uuid=import_entry_uuid)  # pylint: disable=protected-access
            # Replace the folder, possibly destroying existing previous folders, and move the files
            # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
            destdir.replace_with_folder(subfolder.abspath, move=True, overwrite=True)


def _make_import_group(
*, group: Optional[ImportGroup], existing_entries: Dict[str, Dict[str, dict]],
new_entries: Dict[str, Dict[str, dict]], foreign_ids_reverse_mappings: Dict[str, Dict[str, int]]
Expand Down
15 changes: 2 additions & 13 deletions aiida/tools/importexport/dbimport/backends/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
from typing import Any, Dict, Iterable, Optional, Set, Tuple
import warnings

from aiida.common.folders import RepositoryFolder
from aiida.common.links import LinkType, validate_link_label
from aiida.common.progress_reporter import get_progress_reporter
from aiida.common.utils import get_object_from_string, validate_uuid
from aiida.common.warnings import AiidaDeprecationWarning
from aiida.manage.configuration import get_config_option
from aiida.orm.utils._repository import Repository
from aiida.orm import Group

from aiida.tools.importexport.common import exceptions
Expand All @@ -36,7 +34,7 @@
from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader

from aiida.tools.importexport.dbimport.backends.common import (
_make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS
_copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS
)


Expand Down Expand Up @@ -447,17 +445,8 @@ def _store_entity_data(

# Before storing entries in the DB, I store the files (if these are nodes).
# Note: only for new entries!
if objects_to_create:
IMPORT_LOGGER.debug('CREATING NEW NODE REPOSITORIES...')
uuids_to_create = [obj.uuid for obj in objects_to_create]
for import_entry_uuid, subfolder in zip(
uuids_to_create,
reader.iter_node_repos(uuids_to_create, progress=True, description='Copying Repository Folders')
):
destdir = RepositoryFolder(section=Repository._section_name, uuid=import_entry_uuid)
# Replace the folder, possibly destroying existing previous folders, and move the files
# (faster if we are on the same filesystem, and in any case the source is a SandboxFolder)
destdir.replace_with_folder(subfolder.abspath, move=True, overwrite=True)
_copy_node_repositories(uuids_to_create=uuids_to_create, reader=reader)

# For the existing nodes that are also in the imported list we also update their extras if necessary
if existing_entries[entity_name]:
Expand Down
Loading