diff --git a/aiida/common/progress_reporter.py b/aiida/common/progress_reporter.py index f4c1ee9b2c..42df0c383d 100644 --- a/aiida/common/progress_reporter.py +++ b/aiida/common/progress_reporter.py @@ -50,9 +50,25 @@ def __init__(self, *, total: int, desc: Optional[str] = None, **kwargs: Any): :param desc: A description of the process """ - self.total = total - self.desc = desc - self.increment = 0 + self._total = total + self._desc = desc + self._increment: int = 0 + + @property + def total(self) -> int: + """Return the total iterations expected.""" + return self._total + + @property + def desc(self) -> Optional[str]: + """Return the description of the process.""" + return self._desc + + @property + def n(self) -> int: # pylint: disable=invalid-name + """Return the current iteration.""" + # note using `n` as the attribute name is necessary for compatibility with tqdm + return self._increment def __enter__(self) -> 'ProgressReporterAbstract': """Enter the contextmanager.""" @@ -71,7 +87,7 @@ def set_description_str(self, text: Optional[str] = None, refresh: bool = True): :param refresh: Force refresh of the progress reporter """ - self.desc = text + self._desc = text def update(self, n: int = 1): # pylint: disable=invalid-name """Update the progress counter. @@ -79,7 +95,17 @@ def update(self, n: int = 1): # pylint: disable=invalid-name :param n: Increment to add to the internal counter of iterations """ - self.increment += n + self._increment += n + + def reset(self, total: Optional[int] = None): + """Resets current iterations to 0. + + :param total: If not None, update number of expected iterations. + + """ + self._increment = 0 + if total is not None: + self._total = total class ProgressReporterNull(ProgressReporterAbstract): diff --git a/aiida/tools/importexport/archive/readers.py b/aiida/tools/importexport/archive/readers.py index b8de50c717..2ac76966ce 100644 --- a/aiida/tools/importexport/archive/readers.py +++ b/aiida/tools/importexport/archive/readers.py @@ -13,7 +13,7 @@ import os from pathlib import Path from types import TracebackType -from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type +from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type import zipfile import tarfile @@ -22,7 +22,6 @@ from aiida.common.log import AIIDA_LOGGER from aiida.common.exceptions import InvalidOperation from aiida.common.folders import Folder, SandboxFolder -from aiida.common.progress_reporter import get_progress_reporter from aiida.tools.importexport.common.config import EXPORT_VERSION, ExportFileFormat, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.exceptions import (CorruptArchive, IncompatibleArchiveVersionError) from aiida.tools.importexport.archive.common import ArchiveMetadata @@ -37,6 +36,7 @@ 'ReaderJsonTar', 'ReaderJsonZip', 'get_reader', + 'null_callback', ) ARCHIVE_READER_LOGGER = AIIDA_LOGGER.getChild('archive.reader') @@ -58,6 +58,10 @@ def get_reader(file_format: str) -> Type['ArchiveReaderAbstract']: return cast(Type[ArchiveReaderAbstract], readers[file_format]) +def null_callback(action: str, value: Any): # pylint: disable=unused-argument + """A null callback function.""" + + class ArchiveReaderAbstract(ABC): """An abstract interface for AiiDA archive readers. @@ -184,15 +188,21 @@ def iter_link_data(self) -> Iterator[dict]: """Iterate over links: {'input': , 'output': , 'label':