diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml index 951d86f0..a9fc3bbc 100644 --- a/.github/workflows/code-style.yml +++ b/.github/workflows/code-style.yml @@ -16,7 +16,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12 cache: pip cache-dependency-path: | **/pyproject.toml @@ -26,10 +26,10 @@ jobs: run: | bash scripts/build-info.sh - - name: Install Flake8 (5.0.4) + - name: Install Flake8 (7.1.1) run: | python -m pip install --upgrade pip - pip install flake8==5.0.4 + pip install flake8==7.1.1 - name: Install dependencies run: | @@ -51,6 +51,7 @@ jobs: echo "- $package" echo "-------------------------------------------------" pip install packages/$package + rm -rf packages/$package/build echo "=================================================" done @@ -71,7 +72,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12 cache: pip cache-dependency-path: | **/pyproject.toml @@ -81,10 +82,10 @@ jobs: run: | bash scripts/build-info.sh - - name: Install MyPy (1.4.1) + - name: Install MyPy (1.14.0) run: | python -m pip install --upgrade pip - pip install mypy==1.4.1 + pip install mypy==1.14.0 - name: Install dependencies run: | @@ -106,14 +107,17 @@ jobs: echo "- $package" echo "-------------------------------------------------" pip install packages/$package + rm -rf packages/$package/build echo "=================================================" done - name: Check typing with MyPy run: | - mypy --install-types --ignore-missing-imports --check-untyped-defs --non-interactive packages/*/dsw - + mypy --install-types --ignore-missing-imports \ + --check-untyped-defs --non-interactive \ + packages/*/dsw + # Consistency of version tagging version: name: Version consts.py runs-on: ubuntu-latest @@ -145,3 +149,58 @@ jobs: bash scripts/check-version.sh \ packages/dsw-tdk/dsw/tdk/consts.py \ packages/dsw-tdk/pyproject.toml + + # Pylint + pylint: + name: Pylint + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.12 + cache: pip + cache-dependency-path: | + **/pyproject.toml + **/requirements*.txt + + - name: Create build info + run: | + bash scripts/build-info.sh + + - name: Install PyLint (3.3.3) + run: | + python -m pip install --upgrade pip + pip install pylint==3.3.3 + + - name: Install dependencies + run: | + ROOT=$(pwd) + for package in $(ls packages); do + echo "-------------------------------------------------" + echo "- $package" + echo "-------------------------------------------------" + cd "$ROOT/packages/$package" + pip install -r requirements.txt + make local-deps + echo "=================================================" + done + + - name: Install packages + run: | + for package in $(ls packages); do + echo "-------------------------------------------------" + echo "- $package" + echo "-------------------------------------------------" + pip install packages/$package + rm -rf packages/$package/build + echo "=================================================" + done + + - name: Lint with PyLint + run: | + pylint --rcfile=.pylintrc.ini packages/*/dsw diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 52007c66..4463f6a4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,8 +15,8 @@ jobs: - 'ubuntu-latest' - 'windows-latest' python-version: - - '3.10' - '3.11' + - '3.12' runs-on: ${{ matrix.os }} diff --git a/.pylintrc.ini b/.pylintrc.ini new file mode 100644 index 00000000..f9d4f08d --- /dev/null +++ b/.pylintrc.ini @@ -0,0 +1,638 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=dkim + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.11 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + db, + s3, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=10 + +# Maximum number of attributes for a class (see R0902). +max-attributes=20 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=1 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + missing-function-docstring, + missing-module-docstring, + missing-class-docstring, + broad-exception-caught, + duplicate-code, + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work.. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/packages/dsw-command-queue/dsw/command_queue/command_queue.py b/packages/dsw-command-queue/dsw/command_queue/command_queue.py index 62919fb0..17da12eb 100644 --- a/packages/dsw-command-queue/dsw/command_queue/command_queue.py +++ b/packages/dsw-command-queue/dsw/command_queue/command_queue.py @@ -1,13 +1,14 @@ import abc import datetime -import func_timeout import logging import os import platform -import psycopg -import psycopg.generators import select import signal + +import func_timeout +import psycopg +import psycopg.generators import tenacity from dsw.database import Database @@ -22,7 +23,6 @@ RETRY_QUEUE_MULTIPLIER = 0.5 RETRY_QUEUE_TRIES = 5 -INTERRUPTED = False IS_LINUX = platform == 'Linux' if IS_LINUX: @@ -30,16 +30,6 @@ signal.set_wakeup_fd(_QUEUE_PIPE_W) -def signal_handler(recv_signal, frame): - global INTERRUPTED - LOG.warning(f'Received interrupt signal: {recv_signal}') - INTERRUPTED = True - - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGABRT, signal_handler) - - class CommandJobError(BaseException): def __init__(self, job_id: str, message: str, try_again: bool, @@ -56,8 +46,7 @@ def __str__(self): def log_message(self): if self.exc is None: return self.message - else: - return f'{self.message} (caused by: [{type(self.exc).__name__}] {str(self.exc)})' + return f'{self.message} (caused by: [{type(self.exc).__name__}] {str(self.exc)})' def db_message(self): if self.exc is None: @@ -82,7 +71,7 @@ def create(job_id: str, message: str, try_again: bool = True, class CommandWorker: @abc.abstractmethod - def work(self, payload: PersistentCommand): + def work(self, command: PersistentCommand): pass def process_timeout(self, e: BaseException): @@ -94,7 +83,7 @@ def process_exception(self, e: BaseException): class CommandQueue: - def __init__(self, worker: CommandWorker, db: Database, + def __init__(self, *, worker: CommandWorker, db: Database, channel: str, component: str, wait_timeout: float, work_timeout: int | None = None): self.worker = worker @@ -105,6 +94,10 @@ def __init__(self, worker: CommandWorker, db: Database, ) self.wait_timeout = wait_timeout self.work_timeout = work_timeout + self._interrupted = False + + signal.signal(signal.SIGINT, self._signal_handler) + signal.signal(signal.SIGABRT, self._signal_handler) @tenacity.retry( reraise=True, @@ -131,19 +124,19 @@ def run(self): LOG.debug('Waiting for notifications') w = select.select(fds, [], [], self.wait_timeout) - if INTERRUPTED: + if self._interrupted: LOG.debug('Interrupt signal received, ending...') break if w == ([], [], []): - LOG.debug(f'Nothing received in this cycle ' - f'(timeouted after {self.wait_timeout} seconds)') + LOG.debug('Nothing received in this cycle (timeout %s seconds)', + self.wait_timeout) else: notifications = 0 for n in psycopg.generators.notifies(queue_conn.connection.pgconn): notifications += 1 LOG.debug(str(n)) - LOG.info(f'Notifications received ({notifications} in total)') + LOG.info('Notifications received (%s in total)', notifications) LOG.debug('Exiting command queue') @tenacity.retry( @@ -162,11 +155,12 @@ def _fetch_and_process_queued(self): count = 0 while self.fetch_and_process(): count += 1 - LOG.info(f'There are no more commands to process ({count} processed)') + LOG.info('There are no more commands to process (%s processed)', + count) def accept_notification(self, payload: psycopg.Notify) -> bool: - LOG.debug(f'Accepting notification from channel "{payload.channel}" ' - f'(PID = {payload.pid}) {payload.payload}') + LOG.debug('Accepting notification from channel "%s" (PID = %s) %s', + payload.channel, payload.pid, payload.payload) LOG.debug('Trying to fetch a new job') return self.fetch_and_process() @@ -178,16 +172,25 @@ def fetch_and_process(self) -> bool: ) result = cursor.fetchall() if len(result) != 1: - LOG.debug(f'Fetched {len(result)} persistent commands') + LOG.debug('Fetched %s persistent commands', len(result)) return False command = PersistentCommand.from_dict_row(result[0]) - LOG.info(f'Retrieved persistent command {command.uuid} for processing') - LOG.debug(f'Previous state: {command.state}') - LOG.debug(f'Attempts: {command.attempts} / {command.max_attempts}') - LOG.debug(f'Last error: {command.last_error_message}') - attempt_number = command.attempts + 1 + LOG.info('Retrieved persistent command %s for processing', command.uuid) + LOG.debug('Previous state: %s', command.state) + LOG.debug('Attempts: %s / %s', command.attempts, command.max_attempts) + LOG.debug('Last error: %s', command.last_error_message) + self._process(command) + + LOG.debug('Committing transaction') + self.db.conn_query.connection.commit() + cursor.close() + LOG.info('Notification processing finished') + return True + + def _process(self, command: PersistentCommand): + attempt_number = command.attempts + 1 try: self.db.execute_query( query=self.queries.query_command_start(), @@ -204,7 +207,8 @@ def work(): LOG.info('Processing (without any timeout set)') work() else: - LOG.info(f'Processing (with timeout set to {self.work_timeout} seconds)') + LOG.info('Processing (with timeout set to %s seconds)', + self.work_timeout) func_timeout.func_timeout( timeout=self.work_timeout, func=work, @@ -260,8 +264,7 @@ def work(): uuid=command.uuid, ) - LOG.debug('Committing transaction') - self.db.conn_query.connection.commit() - cursor.close() - LOG.info('Notification processing finished') - return True + def _signal_handler(self, recv_signal, frame): + LOG.warning('Received interrupt signal: %s (frame: %s)', + recv_signal, frame) + self._interrupted = True diff --git a/packages/dsw-command-queue/dsw/command_queue/query.py b/packages/dsw-command-queue/dsw/command_queue/query.py index 82874c5b..1535b455 100644 --- a/packages/dsw-command-queue/dsw/command_queue/query.py +++ b/packages/dsw-command-queue/dsw/command_queue/query.py @@ -1,4 +1,7 @@ -class CommandState: +import enum + + +class CommandState(enum.Enum): NEW = 'NewPersistentCommandState' DONE = 'DonePersistentCommandState' ERROR = 'ErrorPersistentCommandState' @@ -20,9 +23,11 @@ def query_get_command(self, exp=2, interval='1 min') -> str: FROM persistent_command WHERE component = '{self.component}' AND attempts < max_attempts - AND state != '{CommandState.DONE}' - AND state != '{CommandState.IGNORE}' - AND (updated_at AT TIME ZONE 'UTC') < (%(now)s - ({exp} ^ attempts - 1) * INTERVAL '{interval}') + AND state != '{CommandState.DONE.value}' + AND state != '{CommandState.IGNORE.value}' + AND (updated_at AT TIME ZONE 'UTC') + < + (%(now)s - ({exp} ^ attempts - 1) * INTERVAL '{interval}') ORDER BY attempts ASC, updated_at DESC LIMIT 1 FOR UPDATE SKIP LOCKED; """ @@ -33,7 +38,7 @@ def query_command_error() -> str: UPDATE persistent_command SET attempts = %(attempts)s, last_error_message = %(error_message)s, - state = '{CommandState.ERROR}', + state = '{CommandState.ERROR.value}', updated_at = %(updated_at)s WHERE uuid = %(uuid)s; """ @@ -55,7 +60,7 @@ def query_command_done() -> str: return f""" UPDATE persistent_command SET attempts = %(attempts)s, - state = '{CommandState.DONE}', + state = '{CommandState.DONE.value}', updated_at = %(updated_at)s WHERE uuid = %(uuid)s; """ diff --git a/packages/dsw-command-queue/pyproject.toml b/packages/dsw-command-queue/pyproject.toml index f2104196..9b02ca0b 100644 --- a/packages/dsw-command-queue/pyproject.toml +++ b/packages/dsw-command-queue/pyproject.toml @@ -16,13 +16,13 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Database', 'Topic :: Text Processing', 'Topic :: Utilities', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'func-timeout', # DSW diff --git a/packages/dsw-config/dsw/config/keys.py b/packages/dsw-config/dsw/config/keys.py index df44715e..db3a4c75 100644 --- a/packages/dsw-config/dsw/config/keys.py +++ b/packages/dsw-config/dsw/config/keys.py @@ -1,60 +1,61 @@ +# pylint: disable=too-few-public-methods import collections +import typing -from typing import Any, Optional, Generic, TypeVar, Callable +T = typing.TypeVar('T') -T = TypeVar('T') - -def cast_bool(value: Any) -> bool: +def cast_bool(value: typing.Any) -> bool: return bool(value) -def cast_optional_bool(value: Any) -> Optional[bool]: +def cast_optional_bool(value: typing.Any) -> bool | None: if value is None: return None return bool(value) -def cast_int(value: Any) -> int: +def cast_int(value: typing.Any) -> int: return int(value) -def cast_optional_int(value: Any) -> Optional[int]: +def cast_optional_int(value: typing.Any) -> int | None: if value is None: return None return int(value) -def cast_float(value: Any) -> float: +def cast_float(value: typing.Any) -> float: return float(value) -def cast_optional_float(value: Any) -> Optional[float]: +def cast_optional_float(value: typing.Any) -> float | None: if value is None: return None return float(value) -def cast_str(value: Any) -> str: +def cast_str(value: typing.Any) -> str: return str(value) -def cast_optional_str(value: Any) -> Optional[str]: +def cast_optional_str(value: typing.Any) -> str | None: if value is None: return None return str(value) -def cast_optional_dict(value: Any) -> Optional[dict]: +def cast_optional_dict(value: typing.Any) -> dict | None: if not isinstance(value, dict): return None return value -class ConfigKey(Generic[T]): +class ConfigKey(typing.Generic[T]): - def __init__(self, yaml_path: list[str], cast: Callable[[Any], T], + def __init__(self, *, yaml_path: list[str], + cast: typing.Callable[[typing.Any], T], var_names=None, default=None, required=False): self.yaml_path = yaml_path self.var_names = var_names or [] # type: list[str] @@ -63,12 +64,13 @@ def __init__(self, yaml_path: list[str], cast: Callable[[Any], T], self.cast = cast def __str__(self): - return 'ConfigKey: ' + '.'.join(self.yaml_path) + return f'ConfigKey: {".".join(self.yaml_path)}' class ConfigKeysMeta(type): @classmethod + # pylint: disable-next=unused-argument def __prepare__(mcs, name, bases, **kwargs): return collections.OrderedDict() diff --git a/packages/dsw-config/dsw/config/logging.py b/packages/dsw-config/dsw/config/logging.py index 944035cf..a7148a76 100644 --- a/packages/dsw-config/dsw/config/logging.py +++ b/packages/dsw-config/dsw/config/logging.py @@ -28,6 +28,8 @@ def __init__(self, *args, **kwargs): def prepare_logging(logging_cfg): + # pylint: disable-next=no-member + logger_dict = logging.root.manager.loggerDict if logging_cfg.dict_config is not None: logging.config.dictConfig(logging_cfg.dict_config) else: @@ -36,15 +38,13 @@ def prepare_logging(logging_cfg): level=logging_cfg.global_level, format=logging_cfg.message_format ) - dsw_loggers = (logging.getLogger(name) - for name in logging.root.manager.loggerDict.keys() + dsw_loggers = (logging.getLogger(name) for name in logger_dict if name.lower().startswith('dsw')) for logger in dsw_loggers: logger.setLevel(logging_cfg.level) # Set for all existing loggers logging.getLogger().addFilter(filter=LOG_FILTER) - loggers = (logging.getLogger(name) - for name in logging.root.manager.loggerDict.keys()) + loggers = (logging.getLogger(name) for name in logger_dict) for logger in loggers: logger.addFilter(filter=LOG_FILTER) # Set for any future loggers diff --git a/packages/dsw-config/dsw/config/model.py b/packages/dsw-config/dsw/config/model.py index 1a2b44db..ae9bfadb 100644 --- a/packages/dsw-config/dsw/config/model.py +++ b/packages/dsw-config/dsw/config/model.py @@ -1,4 +1,4 @@ -from typing import Optional +import dataclasses from .logging import prepare_logging, LOG_FILTER @@ -19,54 +19,44 @@ def __str__(self): return _config_to_string(self) +@dataclasses.dataclass class GeneralConfig(ConfigModel): - - def __init__(self, environment: str, client_url: str, secret: str): - self.environment = environment - self.client_url = client_url - self.secret = secret + environment: str + client_url: str + secret: str +@dataclasses.dataclass class SentryConfig(ConfigModel): - - def __init__(self, enabled: bool, workers_dsn: Optional[str], - traces_sample_rate: Optional[float], max_breadcrumbs: Optional[int], - environment: str): - self.enabled = enabled - self.workers_dsn = workers_dsn - self.traces_sample_rate = traces_sample_rate - self.max_breadcrumbs = max_breadcrumbs - self.environment = environment + enabled: bool + workers_dsn: str | None + traces_sample_rate: float | None + max_breadcrumbs: int | None + environment: str +@dataclasses.dataclass class DatabaseConfig(ConfigModel): - - def __init__(self, connection_string: str, connection_timeout: int, - queue_timeout: int): - self.connection_string = connection_string - self.connection_timeout = connection_timeout - self.queue_timeout = queue_timeout + connection_string: str + connection_timeout: int + queue_timeout: int +@dataclasses.dataclass class S3Config(ConfigModel): - - def __init__(self, url: str, username: str, password: str, - bucket: str, region: str): - self.url = url - self.username = username - self.password = password - self.bucket = bucket - self.region = region + url: str + username: str + password: str + bucket: str + region: str +@dataclasses.dataclass class LoggingConfig(ConfigModel): - - def __init__(self, level: str, global_level: str, message_format: str, - dict_config: Optional[dict] = None): - self.level = level - self.global_level = global_level - self.message_format = message_format - self.dict_config = dict_config + level: str + global_level: str + message_format: str + dict_config: dict | None = None def apply(self): prepare_logging(self) @@ -76,20 +66,17 @@ def set_logging_extra(key: str, value: str): LOG_FILTER.set_extra(key, value) +@dataclasses.dataclass class AWSConfig(ConfigModel): - - def __init__(self, access_key_id: Optional[str], secret_access_key: Optional[str], - region: Optional[str]): - self.access_key_id = access_key_id - self.secret_access_key = secret_access_key - self.region = region + access_key_id: str | None + secret_access_key: str | None + region: str | None @property def has_credentials(self) -> bool: return self.access_key_id is not None and self.secret_access_key is not None +@dataclasses.dataclass class CloudConfig(ConfigModel): - - def __init__(self, multi_tenant: bool): - self.multi_tenant = multi_tenant + multi_tenant: bool diff --git a/packages/dsw-config/dsw/config/parser.py b/packages/dsw-config/dsw/config/parser.py index 38e53077..d1c11b2b 100644 --- a/packages/dsw-config/dsw/config/parser.py +++ b/packages/dsw-config/dsw/config/parser.py @@ -1,7 +1,7 @@ import os -import yaml +import typing -from typing import List, Any, IO +import yaml from .keys import ConfigKey, ConfigKeys from .model import GeneralConfig, SentryConfig, S3Config, \ @@ -10,14 +10,14 @@ class MissingConfigurationError(Exception): - def __init__(self, missing: List[str]): + def __init__(self, missing: list[str]): self.missing = missing class DSWConfigParser: def __init__(self, keys=ConfigKeys): - self.cfg = dict() + self.cfg = {} self.keys = keys @staticmethod @@ -28,7 +28,7 @@ def can_read(content: str): except Exception: return False - def read_file(self, fp: IO): + def read_file(self, fp: typing.IO): self.cfg = yaml.load(fp, Loader=yaml.FullLoader) or self.cfg def read_string(self, content: str): @@ -50,12 +50,12 @@ def has_value_for_key(self, key: ConfigKey): if self.has_value_for_path(key.yaml_path): return True for var_name in key.var_names: - if var_name in os.environ.keys() or \ - self._prefix_var(var_name) in os.environ.keys(): + if var_name in os.environ or self._prefix_var(var_name) in os.environ: return True + return False def get_or_default(self, key: ConfigKey): - x = self.cfg # type: Any + x: typing.Any = self.cfg for p in key.yaml_path: if not hasattr(x, 'keys') or p not in x.keys(): return key.default @@ -64,9 +64,9 @@ def get_or_default(self, key: ConfigKey): def get(self, key: ConfigKey): for var_name in key.var_names: - if var_name in os.environ.keys(): + if var_name in os.environ: return key.cast(os.environ[var_name]) - if self._prefix_var(var_name) in os.environ.keys(): + if self._prefix_var(var_name) in os.environ: return key.cast(os.environ[self._prefix_var(var_name)]) return key.cast(self.get_or_default(key)) diff --git a/packages/dsw-config/dsw/config/sentry.py b/packages/dsw-config/dsw/config/sentry.py index 4dea8fef..e8953782 100644 --- a/packages/dsw-config/dsw/config/sentry.py +++ b/packages/dsw-config/dsw/config/sentry.py @@ -16,7 +16,7 @@ class SentryReporter: filters = [] # type: list[EventProcessor] @classmethod - def initialize(cls, config: SentryConfig, prog_name: str, release: str, + def initialize(cls, *, config: SentryConfig, prog_name: str, release: str, breadcrumb_level: int | None = logging.INFO, event_level: int | None = logging.ERROR): cls.report = config.enabled and config.workers_dsn is not None diff --git a/packages/dsw-config/pyproject.toml b/packages/dsw-config/pyproject.toml index 182bc454..2abd1bcd 100644 --- a/packages/dsw-config/pyproject.toml +++ b/packages/dsw-config/pyproject.toml @@ -16,12 +16,12 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Text Processing', 'Topic :: Utilities', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'PyYAML', 'sentry-sdk', diff --git a/packages/dsw-data-seeder/dsw/data_seeder/cli.py b/packages/dsw-data-seeder/dsw/data_seeder/cli.py index 72a03836..0dd1f933 100644 --- a/packages/dsw-data-seeder/dsw/data_seeder/cli.py +++ b/packages/dsw-data-seeder/dsw/data_seeder/cli.py @@ -1,7 +1,8 @@ -import click # type: ignore import pathlib +import sys +import typing -from typing import IO, Optional +import click from dsw.config.parser import MissingConfigurationError @@ -15,7 +16,7 @@ def load_config_str(config_str: str) -> SeederConfig: parser = SeederConfigParser() if not parser.can_read(config_str): click.echo('Error: Cannot parse config file', err=True) - exit(1) + sys.exit(1) try: parser.read_string(config_str) @@ -24,14 +25,15 @@ def load_config_str(config_str: str) -> SeederConfig: click.echo('Error: Missing configuration', err=True) for missing_item in e.missing: click.echo(f' - {missing_item}', err=True) - exit(1) + sys.exit(1) config = parser.config config.log.apply() return config -def validate_config(ctx, param, value: Optional[IO]): +# pylint: disable-next=unused-argument +def validate_config(ctx, param, value: typing.IO | None): content = '' if value is not None: content = value.read() @@ -75,7 +77,7 @@ def seed(ctx: click.Context, recipe: str, tenant_uuid: str): @cli.command(name='list', help='List recipes for data seeding.') @click.pass_context -def list(ctx: click.Context): +def recipes_list(ctx: click.Context): workdir = ctx.obj['workdir'] recipes = SeedRecipe.load_from_dir(workdir) for recipe in recipes.values(): @@ -84,4 +86,5 @@ def list(ctx: click.Context): def main(): - cli(obj=dict()) + # pylint: disable-next=no-value-for-parameter + cli(obj={}) diff --git a/packages/dsw-data-seeder/dsw/data_seeder/config.py b/packages/dsw-data-seeder/dsw/data_seeder/config.py index 8ab2402d..f4920b10 100644 --- a/packages/dsw-data-seeder/dsw/data_seeder/config.py +++ b/packages/dsw-data-seeder/dsw/data_seeder/config.py @@ -1,3 +1,5 @@ +import dataclasses + from dsw.config import DSWConfigParser from dsw.config.keys import ConfigKey, ConfigKeys, ConfigKeysContainer, \ cast_str, cast_int, cast_optional_int @@ -5,6 +7,7 @@ LoggingConfig, SentryConfig, CloudConfig, GeneralConfig +# pylint: disable-next=too-few-public-methods class _ExperimentalKeys(ConfigKeysContainer): job_timeout = ConfigKey( yaml_path=['experimental', 'jobTimeout'], @@ -14,29 +17,26 @@ class _ExperimentalKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class MailerConfigKeys(ConfigKeys): experimental = _ExperimentalKeys +@dataclasses.dataclass class ExperimentalConfig(ConfigModel): - - def __init__(self, job_timeout: int | None): - self.job_timeout = job_timeout + job_timeout: int | None +@dataclasses.dataclass class SeederConfig: - - def __init__(self, db: DatabaseConfig, s3: S3Config, log: LoggingConfig, - sentry: SentryConfig, cloud: CloudConfig, general: GeneralConfig, - extra_dbs: dict[str, DatabaseConfig], experimental: ExperimentalConfig): - self.general = general - self.db = db - self.s3 = s3 - self.log = log - self.sentry = sentry - self.cloud = cloud - self.extra_dbs = extra_dbs - self.experimental = experimental + db: DatabaseConfig + s3: S3Config + log: LoggingConfig + sentry: SentryConfig + cloud: CloudConfig + general: GeneralConfig + extra_dbs: dict[str, DatabaseConfig] + experimental: ExperimentalConfig def __str__(self): return f'SeederConfig\n' \ @@ -60,7 +60,7 @@ def __init__(self): @property def extra_dbs(self) -> dict[str, DatabaseConfig]: result = {} - for db_id, val in self.cfg.get('extraDatabases', {}).items(): + for db_id in self.cfg.get('extraDatabases', {}).keys(): result[db_id] = DatabaseConfig( connection_string=self.get( key=ConfigKey( diff --git a/packages/dsw-data-seeder/dsw/data_seeder/context.py b/packages/dsw-data-seeder/dsw/data_seeder/context.py index b8307fd1..39ccde54 100644 --- a/packages/dsw-data-seeder/dsw/data_seeder/context.py +++ b/packages/dsw-data-seeder/dsw/data_seeder/context.py @@ -1,11 +1,10 @@ +import dataclasses import pathlib -from typing import Optional, TYPE_CHECKING +from dsw.database import Database +from dsw.storage import S3Storage -if TYPE_CHECKING: - from .config import SeederConfig - from dsw.database import Database - from dsw.storage import S3Storage +from .config import SeederConfig class ContextNotInitializedError(RuntimeError): @@ -14,19 +13,17 @@ def __init__(self): super().__init__('Context cannot be retrieved, not initialized') +@dataclasses.dataclass class AppContext: - - def __init__(self, db, s3, cfg, workdir): - self.db = db # type: Database - self.s3 = s3 # type: S3Storage - self.cfg = cfg # type: SeederConfig - self.workdir = workdir # type: pathlib.Path + db: Database + s3: S3Storage + cfg: SeederConfig + workdir: pathlib.Path +@dataclasses.dataclass class JobContext: - - def __init__(self, trace_id: str): - self.trace_id = trace_id + trace_id: str class _Context: @@ -49,7 +46,7 @@ def reset_ids(self): class Context: - _instance = None # type: Optional[_Context] + _instance: _Context | None = None @classmethod def get(cls) -> _Context: diff --git a/packages/dsw-data-seeder/dsw/data_seeder/handlers.py b/packages/dsw-data-seeder/dsw/data_seeder/handlers.py index 1a7e0005..2a813c46 100644 --- a/packages/dsw-data-seeder/dsw/data_seeder/handlers.py +++ b/packages/dsw-data-seeder/dsw/data_seeder/handlers.py @@ -1,11 +1,14 @@ import os import pathlib +import sys from .cli import load_config_str -from .consts import VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH, VAR_SEEDER_RECIPE +from .consts import (VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH, + VAR_SEEDER_RECIPE, DEFAULT_ENCODING) from .seeder import DataSeeder +# pylint: disable-next=unused-argument def lambda_handler(event, context): config_path = pathlib.Path(os.getenv(VAR_APP_CONFIG_PATH, '/var/task/application.yml')) workdir_path = pathlib.Path(os.getenv(VAR_WORKDIR_PATH, '/var/task/data')) @@ -13,8 +16,8 @@ def lambda_handler(event, context): if recipe_name is None: print(f'Error: Missing recipe name (environment variable {VAR_SEEDER_RECIPE})') - exit(1) + sys.exit(1) - config = load_config_str(config_path.read_text()) + config = load_config_str(config_path.read_text(encoding=DEFAULT_ENCODING)) seeder = DataSeeder(config, workdir_path) seeder.run_once(recipe_name) diff --git a/packages/dsw-data-seeder/dsw/data_seeder/seeder.py b/packages/dsw-data-seeder/dsw/data_seeder/seeder.py index f9a84866..534e6d10 100644 --- a/packages/dsw-data-seeder/dsw/data_seeder/seeder.py +++ b/packages/dsw-data-seeder/dsw/data_seeder/seeder.py @@ -1,13 +1,13 @@ import collections -import dateutil.parser import json import logging import mimetypes import pathlib import time +import typing import uuid -from typing import Optional +import dateutil.parser from dsw.command_queue import CommandWorker, CommandQueue from dsw.config.sentry import SentryReporter @@ -47,11 +47,13 @@ def id(self) -> str: class SeedRecipe: - def __init__(self, name: str, description: str, root: pathlib.Path, + # pylint: disable-next=too-many-arguments + def __init__(self, *, name: str, description: str, root: pathlib.Path, db_scripts: dict[str, DBScript], db_placeholder: str, - s3_dir: Optional[pathlib.Path], s3_fname_replace: dict[str, str], - uuids_count: int, uuids_placeholder: Optional[str], + s3_dir: pathlib.Path | None, s3_fname_replace: dict[str, str], + uuids_count: int, uuids_placeholder: str | None, init_wait: float): + # pylint: disable-next=too-many-instance-attributes self.name = name self.description = description self.root = root @@ -59,12 +61,12 @@ def __init__(self, name: str, description: str, root: pathlib.Path, self.db_placeholder = db_placeholder self.s3_dir = s3_dir self.s3_fname_replace = s3_fname_replace - self._db_scripts_data = collections.OrderedDict() # type: dict[str, str] - self.s3_objects = collections.OrderedDict() # type: dict[pathlib.Path, str] + self._db_scripts_data: dict[str, str] = collections.OrderedDict() + self.s3_objects: dict[pathlib.Path, str] = collections.OrderedDict() self.prepared = False self.uuids_count = uuids_count self.uuids_placeholder = uuids_placeholder - self.uuids_replacement = dict() # type: dict[str, str] + self.uuids_replacement: dict[str, str] = {} self.init_wait = init_wait def _load_db_scripts(self): @@ -147,10 +149,10 @@ def load_from_json(recipe_file: pathlib.Path) -> 'SeedRecipe': data = json.loads(recipe_file.read_text( encoding=DEFAULT_ENCODING, )) - db = data.get('db', {}) # type: dict - s3 = data.get('s3', {}) # type: dict - scripts = db.get('scripts', []) # type: list[dict] - db_scripts = collections.OrderedDict() # type: dict[str, DBScript] + db: dict[str, typing.Any] = data.get('db', {}) + s3: dict[str, typing.Any] = data.get('s3', {}) + scripts: list[dict] = db.get('scripts', []) + db_scripts: dict[str, DBScript] = collections.OrderedDict() for index, script in enumerate(scripts): target = script.get('target', '') filename = str(script.get('filename', '')) @@ -158,7 +160,7 @@ def load_from_json(recipe_file: pathlib.Path) -> 'SeedRecipe': continue filepath = pathlib.Path(filename) if '*' in filename: - for item in sorted([s for s in recipe_file.parent.glob(filename)]): + for item in sorted(list(recipe_file.parent.glob(filename))): s = DBScript(item, target, index) db_scripts[s.id] = s elif filepath.is_absolute(): @@ -272,15 +274,16 @@ def run_once(self, recipe_name: str): queue = self._run_preparation(recipe_name) queue.run_once() - def work(self, cmd: PersistentCommand): - Context.get().update_trace_id(cmd.uuid) - SentryReporter.set_tags(command_uuid=cmd.uuid) + def work(self, command: PersistentCommand): + Context.get().update_trace_id(command.uuid) + SentryReporter.set_tags(command_uuid=command.uuid) self.recipe.run_prepare() - tenant_uuid = cmd.body['tenantUuid'] - LOG.info(f'Seeding recipe "{self.recipe.name}" ' - f'to tenant with UUID "{tenant_uuid}"') - if cmd.attempts == 0 and self.recipe.init_wait > 0.01: - LOG.info(f'Waiting for {self.recipe.init_wait} seconds (first attempt)') + tenant_uuid = command.body['tenantUuid'] + LOG.info('Seeding recipe "%s" to tenant "%s"', + self.recipe.name, tenant_uuid) + if command.attempts == 0 and self.recipe.init_wait > 0.01: + LOG.info('Waiting for %s seconds (first attempt)', + self.recipe.init_wait) time.sleep(self.recipe.init_wait) self.execute(tenant_uuid) Context.get().update_trace_id('-') @@ -295,8 +298,8 @@ def process_exception(self, e: BaseException): @staticmethod def _update_component_info(): built_at = dateutil.parser.parse(BUILD_INFO.built_at) - LOG.info(f'Updating component info ({BUILD_INFO.version}, ' - f'{built_at.isoformat(timespec="seconds")})') + LOG.info('Updating component info (%s, %s)', + BUILD_INFO.version, built_at.isoformat(timespec="seconds")) Context.get().app.db.update_component_info( name=COMPONENT_NAME, version=BUILD_INFO.version, @@ -305,7 +308,7 @@ def _update_component_info(): def seed(self, recipe_name: str, tenant_uuid: str): self._prepare_recipe(recipe_name) - LOG.info(f'Executing recipe "{recipe_name}"') + LOG.info('Executing recipe "%s"', recipe_name) self.execute(tenant_uuid=tenant_uuid) def execute(self, tenant_uuid: str): @@ -317,9 +320,9 @@ def execute(self, tenant_uuid: str): try: LOG.info('Running SQL scripts') for script_id, sql_script in self.recipe.iterate_db_scripts(tenant_uuid): - LOG.debug(f' -> Executing script: {script_id}') + LOG.debug(' -> Executing script: %s', script_id) script = self.recipe.db_scripts[script_id] - if script.target in self.dbs.keys(): + if script.target in self.dbs: used_targets.add(script.target) with self.dbs[script.target].conn_query.new_cursor(use_dict=True) as c: c.execute(query=sql_script) @@ -330,9 +333,9 @@ def execute(self, tenant_uuid: str): phase = 'S3' LOG.info('Transferring S3 objects') for local_file, object_name in self.recipe.iterate_s3_objects(): - LOG.debug(f' -> Reading: {local_file.name}') + LOG.debug(' -> Reading: %s', local_file.name) data = local_file.read_bytes() - LOG.debug(f' -> Sending: {object_name}') + LOG.debug(' -> Sending: %s', object_name) app_ctx.s3.store_object( tenant_uuid=tenant_uuid, object_name=object_name, @@ -341,20 +344,24 @@ def execute(self, tenant_uuid: str): ) LOG.debug(' OK (stored)') except Exception as e: - LOG.warning(f'Exception appeared [{type(e).__name__}]: {e}') + LOG.warning('Exception appeared [%s]: %s', type(e).__name__, e) LOG.error('Failed with unexpected error', exc_info=e) LOG.info('Rolling back DB changes') - LOG.debug(f'Used extra DBs: {used_targets}') + LOG.debug('Used extra DBs: %s', str(used_targets)) conn = app_ctx.db.conn_query.connection - LOG.debug(f'DEFAULT will roll back: {conn.pgconn.status} / {conn.pgconn.transaction_status}') + LOG.debug('DEFAULT will roll back: %s / %s', + conn.pgconn.status, conn.pgconn.transaction_status) conn.rollback() - LOG.debug(f'DEFAULT rolled back: {conn.pgconn.status} / {conn.pgconn.transaction_status}') + LOG.debug('DEFAULT rolled back: %s / %s', + conn.pgconn.status, conn.pgconn.transaction_status) for target in used_targets: conn = self.dbs[target].conn_query.connection - LOG.debug(f'{target} will roll back: {conn.pgconn.status} / {conn.pgconn.transaction_status}') + LOG.debug('%s will roll back: %s / %s', + target, conn.pgconn.status, conn.pgconn.transaction_status) conn.rollback() - LOG.debug(f'{target} rolled back: {conn.pgconn.status} / {conn.pgconn.transaction_status}') + LOG.debug('%s rolled back: %s / %s', + target, conn.pgconn.status, conn.pgconn.transaction_status) raise RuntimeError(f'{phase}: {e}') from e else: LOG.info('Committing DB changes') diff --git a/packages/dsw-data-seeder/pyproject.toml b/packages/dsw-data-seeder/pyproject.toml index 9e889633..f66c5974 100644 --- a/packages/dsw-data-seeder/pyproject.toml +++ b/packages/dsw-data-seeder/pyproject.toml @@ -16,13 +16,13 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Database', 'Topic :: Text Processing', 'Topic :: Utilities', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'click', 'python-dateutil', diff --git a/packages/dsw-database/dsw/database/database.py b/packages/dsw-database/dsw/database/database.py index 38afa952..16f7844e 100644 --- a/packages/dsw-database/dsw/database/database.py +++ b/packages/dsw-database/dsw/database/database.py @@ -1,13 +1,13 @@ import datetime import logging +import typing + import psycopg import psycopg.conninfo import psycopg.rows import psycopg.types.json import tenacity -from typing import List, Iterable, Optional - from dsw.config.model import DatabaseConfig from .model import DBDocumentTemplate, DBDocumentTemplateFile, \ @@ -28,46 +28,56 @@ def wrap_json_data(data: dict): return psycopg.types.json.Json(data) +# pylint: disable-next=too-many-public-methods class Database: - # TODO: refactor queries and models - SELECT_DOCUMENT = 'SELECT * FROM document WHERE uuid = %s AND tenant_uuid = %s LIMIT 1;' - SELECT_QTN_DOCUMENTS = 'SELECT * FROM document WHERE questionnaire_uuid = %s AND tenant_uuid = %s;' - SELECT_DOCUMENT_SUBMISSIONS = 'SELECT * FROM submission WHERE document_uuid = %s AND tenant_uuid = %s;' - SELECT_QTN_SUBMISSIONS = 'SELECT s.* FROM document d JOIN submission s ON d.uuid = s.document_uuid ' \ - 'WHERE d.questionnaire_uuid = %s AND d.tenant_uuid = %s;' - SELECT_QTN_SIMPLE = 'SELECT qtn.* FROM questionnaire qtn ' \ - 'WHERE qtn.uuid = %s AND qtn.tenant_uuid = %s;' - SELECT_TENANT_CONFIG = 'SELECT * FROM tenant_config WHERE uuid = %(tenant_uuid)s LIMIT 1;' - SELECT_TENANT_LIMIT = 'SELECT uuid, storage FROM tenant_limit_bundle WHERE uuid = %(tenant_uuid)s LIMIT 1;' + SELECT_DOCUMENT = ('SELECT * FROM document ' + 'WHERE uuid = %s AND tenant_uuid = %s LIMIT 1;') + SELECT_QTN_DOCUMENTS = ('SELECT * FROM document ' + 'WHERE questionnaire_uuid = %s AND tenant_uuid = %s;') + SELECT_DOCUMENT_SUBMISSIONS = ('SELECT * FROM submission ' + 'WHERE document_uuid = %s AND tenant_uuid = %s;') + SELECT_QTN_SUBMISSIONS = ('SELECT s.* ' + 'FROM document d JOIN submission s ON d.uuid = s.document_uuid ' + 'WHERE d.questionnaire_uuid = %s AND d.tenant_uuid = %s;') + SELECT_QTN_SIMPLE = ('SELECT qtn.* FROM questionnaire qtn ' + 'WHERE qtn.uuid = %s AND qtn.tenant_uuid = %s;') + SELECT_TENANT_CONFIG = ('SELECT * FROM tenant_config ' + 'WHERE uuid = %(tenant_uuid)s LIMIT 1;') + SELECT_TENANT_LIMIT = ('SELECT uuid, storage FROM tenant_limit_bundle ' + 'WHERE uuid = %(tenant_uuid)s LIMIT 1;') UPDATE_DOCUMENT_STATE = 'UPDATE document SET state = %s, worker_log = %s WHERE uuid = %s;' UPDATE_DOCUMENT_RETRIEVED = 'UPDATE document SET retrieved_at = %s, state = %s WHERE uuid = %s;' - UPDATE_DOCUMENT_FINISHED = 'UPDATE document SET finished_at = %s, state = %s, ' \ - 'file_name = %s, content_type = %s, worker_log = %s, ' \ - 'file_size = %s WHERE uuid = %s;' - SELECT_TEMPLATE = 'SELECT * FROM document_template WHERE id = %s AND tenant_uuid = %s LIMIT 1;' - SELECT_TEMPLATE_FILES = 'SELECT * FROM document_template_file ' \ - 'WHERE document_template_id = %s AND tenant_uuid = %s;' - SELECT_TEMPLATE_ASSETS = 'SELECT * FROM document_template_asset ' \ - 'WHERE document_template_id = %s AND tenant_uuid = %s;' - CHECK_TABLE_EXISTS = 'SELECT EXISTS(SELECT * FROM information_schema.tables' \ - ' WHERE table_name = %(table_name)s)' - SELECT_MAIL_CONFIG = 'SELECT icm.* ' \ - 'FROM tenant_config tc JOIN instance_config_mail icm ' \ - 'ON tc.mail_config_uuid = icm.uuid ' \ - 'WHERE tc.uuid = %(tenant_uuid)s;' - UPDATE_COMPONENT_INFO = 'INSERT INTO component (name, version, built_at, created_at, updated_at) ' \ - 'VALUES (%(name)s, %(version)s, %(built_at)s, %(created_at)s, %(updated_at)s)' \ - 'ON CONFLICT (name) DO ' \ - 'UPDATE SET version = %(version)s, built_at = %(built_at)s, updated_at = %(updated_at)s;' + UPDATE_DOCUMENT_FINISHED = ('UPDATE document SET finished_at = %s, state = %s, ' + 'file_name = %s, content_type = %s, worker_log = %s, ' + 'file_size = %s WHERE uuid = %s;') + SELECT_TEMPLATE = ('SELECT * FROM document_template ' + 'WHERE id = %s AND tenant_uuid = %s LIMIT 1;') + SELECT_TEMPLATE_FILES = ('SELECT * FROM document_template_file ' + 'WHERE document_template_id = %s AND tenant_uuid = %s;') + SELECT_TEMPLATE_ASSETS = ('SELECT * FROM document_template_asset ' + 'WHERE document_template_id = %s AND tenant_uuid = %s;') + CHECK_TABLE_EXISTS = ('SELECT EXISTS(SELECT * FROM information_schema.tables' + ' WHERE table_name = %(table_name)s)') + SELECT_MAIL_CONFIG = ('SELECT icm.* ' + 'FROM tenant_config tc JOIN instance_config_mail icm ' + 'ON tc.mail_config_uuid = icm.uuid ' + 'WHERE tc.uuid = %(tenant_uuid)s;') + UPDATE_COMPONENT_INFO = ('INSERT INTO component ' + '(name, version, built_at, created_at, updated_at) ' + 'VALUES (%(name)s, %(version)s, %(built_at)s, ' + '%(created_at)s, %(updated_at)s)' + 'ON CONFLICT (name) DO ' + 'UPDATE SET version = %(version)s, built_at = %(built_at)s, ' + 'updated_at = %(updated_at)s;') SELECT_COMPONENT_INFO = 'SELECT * FROM component WHERE name = %(name)s;' - SUM_FILE_SIZES = 'SELECT (SELECT COALESCE(SUM(file_size)::bigint, 0) ' \ - 'FROM document WHERE tenant_uuid = %(tenant_uuid)s) ' \ - '+ (SELECT COALESCE(SUM(file_size)::bigint, 0) ' \ - 'FROM document_template_asset WHERE tenant_uuid = %(tenant_uuid)s) ' \ - '+ (SELECT COALESCE(SUM(file_size)::bigint, 0) ' \ - 'FROM questionnaire_file WHERE tenant_uuid = %(tenant_uuid)s) ' \ - 'AS result;' + SUM_FILE_SIZES = ('SELECT (SELECT COALESCE(SUM(file_size)::bigint, 0) ' + 'FROM document WHERE tenant_uuid = %(tenant_uuid)s) ' + '+ (SELECT COALESCE(SUM(file_size)::bigint, 0) ' + 'FROM document_template_asset WHERE tenant_uuid = %(tenant_uuid)s) ' + '+ (SELECT COALESCE(SUM(file_size)::bigint, 0) ' + 'FROM questionnaire_file WHERE tenant_uuid = %(tenant_uuid)s) ' + 'AS result;') def __init__(self, cfg: DatabaseConfig, connect: bool = True, with_queue: bool = True): @@ -123,7 +133,7 @@ def _check_table_exists(self, table_name: str) -> bool: before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_document(self, document_uuid: str, tenant_uuid: str) -> Optional[DBDocument]: + def fetch_document(self, document_uuid: str, tenant_uuid: str) -> DBDocument | None: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_DOCUMENT, @@ -141,7 +151,7 @@ def fetch_document(self, document_uuid: str, tenant_uuid: str) -> Optional[DBDoc before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_tenant_config(self, tenant_uuid: str) -> Optional[DBTenantConfig]: + def fetch_tenant_config(self, tenant_uuid: str) -> DBTenantConfig | None: return self.get_tenant_config(tenant_uuid) @tenacity.retry( @@ -151,7 +161,7 @@ def fetch_tenant_config(self, tenant_uuid: str) -> Optional[DBTenantConfig]: before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_tenant_limits(self, tenant_uuid: str) -> Optional[DBTenantLimits]: + def fetch_tenant_limits(self, tenant_uuid: str) -> DBTenantLimits | None: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_TENANT_LIMIT, @@ -171,7 +181,7 @@ def fetch_tenant_limits(self, tenant_uuid: str) -> Optional[DBTenantLimits]: ) def fetch_template( self, template_id: str, tenant_uuid: str - ) -> Optional[DBDocumentTemplate]: + ) -> DBDocumentTemplate | None: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_TEMPLATE, @@ -191,7 +201,7 @@ def fetch_template( ) def fetch_template_files( self, template_id: str, tenant_uuid: str - ) -> List[DBDocumentTemplateFile]: + ) -> list[DBDocumentTemplateFile]: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_TEMPLATE_FILES, @@ -208,7 +218,7 @@ def fetch_template_files( ) def fetch_template_assets( self, template_id: str, tenant_uuid: str - ) -> List[DBDocumentTemplateAsset]: + ) -> list[DBDocumentTemplateAsset]: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_TEMPLATE_ASSETS, @@ -223,7 +233,8 @@ def fetch_template_assets( before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_qtn_documents(self, questionnaire_uuid: str, tenant_uuid: str) -> List[DBDocument]: + def fetch_qtn_documents(self, questionnaire_uuid: str, + tenant_uuid: str) -> list[DBDocument]: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_QTN_DOCUMENTS, @@ -238,7 +249,8 @@ def fetch_qtn_documents(self, questionnaire_uuid: str, tenant_uuid: str) -> List before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_document_submissions(self, document_uuid: str, tenant_uuid: str) -> List[DBSubmission]: + def fetch_document_submissions(self, document_uuid: str, + tenant_uuid: str) -> list[DBSubmission]: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_DOCUMENT_SUBMISSIONS, @@ -253,7 +265,8 @@ def fetch_document_submissions(self, document_uuid: str, tenant_uuid: str) -> Li before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_questionnaire_submissions(self, questionnaire_uuid: str, tenant_uuid: str) -> List[DBSubmission]: + def fetch_questionnaire_submissions(self, questionnaire_uuid: str, + tenant_uuid: str) -> list[DBSubmission]: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_QTN_SUBMISSIONS, @@ -268,7 +281,8 @@ def fetch_questionnaire_submissions(self, questionnaire_uuid: str, tenant_uuid: before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def fetch_questionnaire_simple(self, questionnaire_uuid: str, tenant_uuid: str) -> DBQuestionnaireSimple: + def fetch_questionnaire_simple(self, questionnaire_uuid: str, + tenant_uuid: str) -> DBQuestionnaireSimple: with self.conn_query.new_cursor(use_dict=True) as cursor: cursor.execute( query=self.SELECT_QTN_SIMPLE, @@ -305,7 +319,7 @@ def update_document_retrieved(self, retrieved_at: datetime.datetime, query=self.UPDATE_DOCUMENT_RETRIEVED, params=( retrieved_at, - DocumentState.PROCESSING, + DocumentState.PROCESSING.value, document_uuid, ), ) @@ -319,7 +333,7 @@ def update_document_retrieved(self, retrieved_at: datetime.datetime, after=tenacity.after_log(LOG, logging.DEBUG), ) def update_document_finished( - self, finished_at: datetime.datetime, file_name: str, file_size: int, + self, *, finished_at: datetime.datetime, file_name: str, file_size: int, content_type: str, worker_log: str, document_uuid: str ) -> bool: with self.conn_query.new_cursor() as cursor: @@ -327,7 +341,7 @@ def update_document_finished( query=self.UPDATE_DOCUMENT_FINISHED, params=( finished_at, - DocumentState.FINISHED, + DocumentState.FINISHED.value, file_name, content_type, worker_log, @@ -360,7 +374,7 @@ def get_currently_used_size(self, tenant_uuid: str): before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def get_tenant_config(self, tenant_uuid: str) -> Optional[DBTenantConfig]: + def get_tenant_config(self, tenant_uuid: str) -> DBTenantConfig | None: if not self._check_table_exists(table_name='tenant_config'): return None with self.conn_query.new_cursor(use_dict=True) as cursor: @@ -372,8 +386,8 @@ def get_tenant_config(self, tenant_uuid: str) -> Optional[DBTenantConfig]: result = cursor.fetchone() return DBTenantConfig.from_dict_row(data=result) except Exception as e: - LOG.warning(f'Could not retrieve tenant_config for tenant' - f' "{tenant_uuid}": {str(e)}') + LOG.warning('Could not retrieve tenant_config for tenant "%s": %s', + tenant_uuid, str(e)) return None @tenacity.retry( @@ -383,7 +397,7 @@ def get_tenant_config(self, tenant_uuid: str) -> Optional[DBTenantConfig]: before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def get_mail_config(self, tenant_uuid: str) -> Optional[DBInstanceConfigMail]: + def get_mail_config(self, tenant_uuid: str) -> DBInstanceConfigMail | None: with self.conn_query.new_cursor(use_dict=True) as cursor: if not self._check_table_exists(table_name='instance_config_mail'): return None @@ -397,8 +411,8 @@ def get_mail_config(self, tenant_uuid: str) -> Optional[DBInstanceConfigMail]: return None return DBInstanceConfigMail.from_dict_row(data=result) except Exception as e: - LOG.warning(f'Could not retrieve instance_config_mail for tenant' - f' "{tenant_uuid}": {str(e)}') + LOG.warning('Could not retrieve instance_config_mail for tenant "%s": %s', + tenant_uuid, str(e)) return None @tenacity.retry( @@ -411,7 +425,7 @@ def get_mail_config(self, tenant_uuid: str) -> Optional[DBInstanceConfigMail]: def update_component_info(self, name: str, version: str, built_at: datetime.datetime): with self.conn_query.new_cursor(use_dict=True) as cursor: if not self._check_table_exists(table_name='component'): - return None + return ts_now = datetime.datetime.now(tz=datetime.UTC) try: cursor.execute( @@ -426,7 +440,7 @@ def update_component_info(self, name: str, version: str, built_at: datetime.date ) self.conn_query.connection.commit() except Exception as e: - LOG.warning(f'Could not update component info: {str(e)}') + LOG.warning('Could not update component info: %s', str(e)) @tenacity.retry( reraise=True, @@ -435,7 +449,7 @@ def update_component_info(self, name: str, version: str, built_at: datetime.date before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def get_component_info(self, name: str) -> Optional[DBComponent]: + def get_component_info(self, name: str) -> DBComponent | None: if not self._check_table_exists(table_name='component'): return None with self.conn_query.new_cursor(use_dict=True) as cursor: @@ -449,7 +463,7 @@ def get_component_info(self, name: str) -> Optional[DBComponent]: return None return DBComponent.from_dict_row(data=result) except Exception as e: - LOG.warning(f'Could not get component info: {str(e)}') + LOG.warning('Could not get component info: %s', str(e)) return None @tenacity.retry( @@ -459,7 +473,7 @@ def get_component_info(self, name: str) -> Optional[DBComponent]: before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def execute_queries(self, queries: Iterable[str]): + def execute_queries(self, queries: typing.Iterable[str]): with self.conn_query.new_cursor(use_dict=True) as cursor: for query in queries: cursor.execute(query=query) @@ -486,7 +500,7 @@ def __init__(self, name: str, dsn: str, timeout=30000, autocommit=False): connect_timeout=timeout, ) self.autocommit = autocommit - self._connection = None # type: Optional[psycopg.Connection] + self._connection: psycopg.Connection | None = None @tenacity.retry( reraise=True, @@ -496,11 +510,15 @@ def __init__(self, name: str, dsn: str, timeout=30000, autocommit=False): after=tenacity.after_log(LOG, logging.DEBUG), ) def _connect_db(self): - LOG.info(f'Creating connection to PostgreSQL database "{self.name}"') + LOG.info('Creating connection to PostgreSQL database "%s"', self.name) try: - connection = psycopg.connect(conninfo=self.dsn, autocommit=self.autocommit) # type: psycopg.Connection + connection: psycopg.Connection = psycopg.connect( + conninfo=self.dsn, + autocommit=self.autocommit, + ) except Exception as e: - LOG.error(f'Failed to connect to PostgreSQL database "{self.name}": {str(e)}') + LOG.error('Failed to connect to PostgreSQL database "%s": %s', + self.name, str(e)) raise e # test connection cursor = connection.cursor() @@ -508,7 +526,7 @@ def _connect_db(self): result = cursor.fetchone() if result is None: raise RuntimeError('Failed to verify DB connection') - LOG.debug(f'DB connection verified (result={result[0]})') + LOG.debug('DB connection verified (result=%s)', result[0]) cursor.close() connection.commit() self._connection = connection @@ -534,6 +552,6 @@ def reset(self): def close(self): if self._connection: - LOG.info(f'Closing connection to PostgreSQL database "{self.name}"') + LOG.info('Closing connection to PostgreSQL database "%s"', self.name) self._connection.close() self._connection = None diff --git a/packages/dsw-database/dsw/database/model.py b/packages/dsw-database/dsw/database/model.py index 2461ee80..9a14a6ca 100644 --- a/packages/dsw-database/dsw/database/model.py +++ b/packages/dsw-database/dsw/database/model.py @@ -1,21 +1,20 @@ import dataclasses import datetime +import enum import json -from typing import Optional - NULL_UUID = '00000000-0000-0000-0000-000000000000' -class DocumentState: +class DocumentState(enum.Enum): QUEUED = 'QueuedDocumentState' PROCESSING = 'InProgressDocumentState' FAILED = 'ErrorDocumentState' FINISHED = 'DoneDocumentState' -class DocumentTemplatePhase: +class DocumentTemplatePhase(enum.Enum): RELEASED = 'ReleasedTemplatePhase' DEPRECATED = 'DeprecatedTemplatePhase' DRAFT = 'DraftTemplatePhase' @@ -55,8 +54,8 @@ class DBDocument: content_type: str worker_log: str created_by: str - retrieved_at: Optional[datetime.datetime] - finished_at: Optional[datetime.datetime] + retrieved_at: datetime.datetime | None + finished_at: datetime.datetime | None created_at: datetime.datetime tenant_uuid: str file_size: int @@ -191,11 +190,11 @@ class PersistentCommand: component: str function: str body: dict - last_error_message: Optional[str] + last_error_message: str | None attempts: int max_attempts: int tenant_uuid: str - created_by: Optional[str] + created_by: str | None created_at: datetime.datetime updated_at: datetime.datetime @@ -220,28 +219,28 @@ def from_dict_row(data: dict): @dataclasses.dataclass class DBTenantConfig: uuid: str - organization: Optional[dict] - authentication: Optional[dict] - privacy_and_support: Optional[dict] - dashboard: Optional[dict] - look_and_feel: Optional[dict] - registry: Optional[dict] - knowledge_model: Optional[dict] - questionnaire: Optional[dict] - submission: Optional[dict] - owl: Optional[dict] - mail_config_uuid: Optional[str] + organization: dict | None + authentication: dict | None + privacy_and_support: dict | None + dashboard: dict | None + look_and_feel: dict | None + registry: dict | None + knowledge_model: dict | None + questionnaire: dict | None + submission: dict | None + owl: dict | None + mail_config_uuid: str | None created_at: datetime.datetime updated_at: datetime.datetime @property - def app_title(self) -> Optional[str]: + def app_title(self) -> str | None: if self.look_and_feel is None: return None return self.look_and_feel.get('appTitle', None) @property - def support_email(self) -> Optional[str]: + def support_email(self) -> str | None: if self.privacy_and_support is None: return None return self.privacy_and_support.get('supportEmail', None) @@ -269,7 +268,7 @@ def from_dict_row(data: dict): @dataclasses.dataclass class DBTenantLimits: tenant_uuid: str - storage: Optional[int] + storage: int | None @staticmethod def from_dict_row(data: dict): @@ -390,19 +389,19 @@ class DBInstanceConfigMail: uuid: str enabled: bool provider: str - sender_name: Optional[str] - sender_email: Optional[str] - smtp_host: Optional[str] - smtp_port: Optional[int] - smtp_security: Optional[str] - smtp_username: Optional[str] - smtp_password: Optional[str] - aws_access_key_id: Optional[str] - aws_secret_access_key: Optional[str] - aws_region: Optional[str] - rate_limit_window: Optional[int] - rate_limit_count: Optional[int] - timeout: Optional[int] + sender_name: str | None + sender_email: str | None + smtp_host: str | None + smtp_port: int | None + smtp_security: str | None + smtp_username: str | None + smtp_password: str | None + aws_access_key_id: str | None + aws_secret_access_key: str | None + aws_region: str | None + rate_limit_window: int | None + rate_limit_count: int | None + timeout: int | None @staticmethod def from_dict_row(data: dict): diff --git a/packages/dsw-database/pyproject.toml b/packages/dsw-database/pyproject.toml index 92f03e48..3b44fd84 100644 --- a/packages/dsw-database/pyproject.toml +++ b/packages/dsw-database/pyproject.toml @@ -16,12 +16,12 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Database', 'Topic :: Utilities', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'psycopg[binary]', 'tenacity', diff --git a/packages/dsw-document-worker/dsw/document_worker/__main__.py b/packages/dsw-document-worker/dsw/document_worker/__main__.py index 9c388806..e4ffebab 100644 --- a/packages/dsw-document-worker/dsw/document_worker/__main__.py +++ b/packages/dsw-document-worker/dsw/document_worker/__main__.py @@ -1,4 +1,5 @@ from .cli import main from .consts import PROG_NAME +# pylint: disable-next=no-value-for-parameter main(prog_name=PROG_NAME) diff --git a/packages/dsw-document-worker/dsw/document_worker/cli.py b/packages/dsw-document-worker/dsw/document_worker/cli.py index 700cddd5..82ec57da 100644 --- a/packages/dsw-document-worker/dsw/document_worker/cli.py +++ b/packages/dsw-document-worker/dsw/document_worker/cli.py @@ -1,7 +1,8 @@ -import click import pathlib +import sys +import typing -from typing import IO, Optional +import click from dsw.config.parser import MissingConfigurationError from dsw.config.sentry import SentryReporter @@ -9,13 +10,14 @@ from .config import DocumentWorkerConfig, DocumentWorkerConfigParser from .consts import (VERSION, VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH, DEFAULT_ENCODING) +from .worker import DocumentWorker def load_config_str(config_str: str) -> DocumentWorkerConfig: parser = DocumentWorkerConfigParser() if not parser.can_read(config_str): click.echo('Error: Cannot parse config file', err=True) - exit(1) + sys.exit(1) try: parser.read_string(config_str) @@ -24,14 +26,15 @@ def load_config_str(config_str: str) -> DocumentWorkerConfig: click.echo('Error: Missing configuration', err=True) for missing_item in e.missing: click.echo(f' - {missing_item}', err=True) - exit(1) + sys.exit(1) config = parser.config config.log.apply() return config -def validate_config(ctx, param, value: Optional[IO]): +# pylint: disable-next=unused-argument +def validate_config(ctx, param, value: typing.IO | None): content = '' if value is not None: content = value.read() @@ -46,17 +49,16 @@ def validate_config(ctx, param, value: Optional[IO]): type=click.File('r', encoding=DEFAULT_ENCODING)) @click.argument('workdir', envvar=VAR_WORKDIR_PATH,) def main(config: DocumentWorkerConfig, workdir: str): - from .worker import DocumentWorker config.log.apply() workdir_path = pathlib.Path(workdir) workdir_path.mkdir(parents=True, exist_ok=True) if not workdir_path.is_dir(): click.echo(f'Workdir {workdir_path.as_posix()} is not usable') - exit(2) + sys.exit(2) try: worker = DocumentWorker(config, workdir_path) worker.run() except Exception as e: SentryReporter.capture_exception(e) click.echo(f'Ended with error: {e}') - exit(2) + sys.exit(2) diff --git a/packages/dsw-document-worker/dsw/document_worker/config.py b/packages/dsw-document-worker/dsw/document_worker/config.py index 4e0ab935..c471e15b 100644 --- a/packages/dsw-document-worker/dsw/document_worker/config.py +++ b/packages/dsw-document-worker/dsw/document_worker/config.py @@ -1,5 +1,5 @@ +import dataclasses import shlex -from typing import List, Optional from dsw.config import DSWConfigParser from dsw.config.keys import ConfigKey, ConfigKeys, ConfigKeysContainer, \ @@ -10,6 +10,7 @@ from .consts import DocumentNamingStrategy +# pylint: disable-next=too-few-public-methods class _DocumentsKeys(ConfigKeysContainer): naming_strategy = ConfigKey( yaml_path=['documents', 'naming', 'strategy'], @@ -19,6 +20,7 @@ class _DocumentsKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class _ExperimentalKeys(ConfigKeysContainer): job_timeout = ConfigKey( yaml_path=['experimental', 'jobTimeout'], @@ -34,6 +36,7 @@ class _ExperimentalKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class _CommandPandocKeys(ConfigKeysContainer): executable = ConfigKey( yaml_path=['externals', 'pandoc', 'executable'], @@ -55,44 +58,43 @@ class _CommandPandocKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class DocWorkerConfigKeys(ConfigKeys): documents = _DocumentsKeys experimental = _ExperimentalKeys cmd_pandoc = _CommandPandocKeys +@dataclasses.dataclass class DocumentsConfig(ConfigModel): + naming_strategy: str def __init__(self, naming_strategy: str): self.naming_strategy = DocumentNamingStrategy.get(naming_strategy) +@dataclasses.dataclass class ExperimentalConfig(ConfigModel): - - def __init__(self, job_timeout: Optional[int], - max_doc_size: Optional[float]): - self.job_timeout = job_timeout - self.max_doc_size = max_doc_size + job_timeout: int | None + max_doc_size: int | None +@dataclasses.dataclass class CommandConfig: - - def __init__(self, executable: str, args: str, timeout: float): - self.executable = executable - self.args = args - self.timeout = timeout + executable: str + args: str + timeout: float @property - def command(self) -> List[str]: + def command(self) -> list[str]: return [self.executable] + shlex.split(self.args) +@dataclasses.dataclass class TemplateRequestsConfig: - - def __init__(self, enabled: bool, limit: int, timeout: int): - self.enabled = enabled - self.limit = limit - self.timeout = timeout + enabled: bool + limit: int + timeout: int @staticmethod def load(data: dict): @@ -103,14 +105,12 @@ def load(data: dict): ) +@dataclasses.dataclass class TemplateConfig: - - def __init__(self, ids: List[str], requests: TemplateRequestsConfig, - secrets: dict[str, str], send_sentry: bool): - self.ids = ids - self.requests = requests - self.secrets = secrets - self.send_sentry = send_sentry + ids: list[str] + requests: TemplateRequestsConfig + secrets: dict[str, str] + send_sentry: bool @staticmethod def load(data: dict): @@ -124,12 +124,11 @@ def load(data: dict): ) +@dataclasses.dataclass class TemplatesConfig: + templates: list[TemplateConfig] - def __init__(self, templates: List[TemplateConfig]): - self.templates = templates - - def get_config(self, template_id: str) -> Optional[TemplateConfig]: + def get_config(self, template_id: str) -> TemplateConfig | None: for template in self.templates: if any((template_id.startswith(prefix) for prefix in template.ids)): @@ -137,22 +136,18 @@ def get_config(self, template_id: str) -> Optional[TemplateConfig]: return None +@dataclasses.dataclass class DocumentWorkerConfig: - - def __init__(self, db: DatabaseConfig, s3: S3Config, log: LoggingConfig, - doc: DocumentsConfig, pandoc: CommandConfig, - templates: TemplatesConfig, experimental: ExperimentalConfig, - cloud: CloudConfig, sentry: SentryConfig, general: GeneralConfig): - self.db = db - self.s3 = s3 - self.log = log - self.doc = doc - self.pandoc = pandoc - self.templates = templates - self.experimental = experimental - self.cloud = cloud - self.sentry = sentry - self.general = general + db: DatabaseConfig + s3: S3Config + log: LoggingConfig + doc: DocumentsConfig + pandoc: CommandConfig + templates: TemplatesConfig + experimental: ExperimentalConfig + cloud: CloudConfig + sentry: SentryConfig + general: GeneralConfig def __str__(self): return f'DocumentWorkerConfig\n' \ @@ -170,7 +165,6 @@ def __str__(self): class DocumentWorkerConfigParser(DSWConfigParser): - TEMPLATES_SECTION = 'templates' def __init__(self): diff --git a/packages/dsw-document-worker/dsw/document_worker/consts.py b/packages/dsw-document-worker/dsw/document_worker/consts.py index 1f22cefc..a0668389 100644 --- a/packages/dsw-document-worker/dsw/document_worker/consts.py +++ b/packages/dsw-document-worker/dsw/document_worker/consts.py @@ -1,3 +1,4 @@ +# pylint: disable=too-few-public-methods CMD_CHANNEL = 'doc_worker' CMD_COMPONENT = 'doc_worker' COMPONENT_NAME = 'Document Worker' @@ -49,5 +50,5 @@ class DocumentNamingStrategy: } @classmethod - def get(cls, name: str): + def get(cls, name: str) -> str: return cls._NAMES.get(name.lower(), cls._DEFAULT) diff --git a/packages/dsw-document-worker/dsw/document_worker/context.py b/packages/dsw-document-worker/dsw/document_worker/context.py index 2862dfcc..cc67d453 100644 --- a/packages/dsw-document-worker/dsw/document_worker/context.py +++ b/packages/dsw-document-worker/dsw/document_worker/context.py @@ -1,11 +1,10 @@ +import dataclasses import pathlib -from typing import Optional, TYPE_CHECKING +from dsw.database import Database +from dsw.storage import S3Storage -if TYPE_CHECKING: - from .config import DocumentWorkerConfig - from dsw.database import Database - from dsw.storage import S3Storage +from .config import DocumentWorkerConfig class ContextNotInitializedError(RuntimeError): @@ -14,19 +13,17 @@ def __init__(self): super().__init__('Context cannot be retrieved, not initialized') +@dataclasses.dataclass class AppContext: - - def __init__(self, db, s3, cfg, workdir): - self.db = db # type: Database - self.s3 = s3 # type: S3Storage - self.cfg = cfg # type: DocumentWorkerConfig - self.workdir = workdir # type: pathlib.Path + db: Database + s3: S3Storage + cfg: DocumentWorkerConfig + workdir: pathlib.Path +@dataclasses.dataclass class JobContext: - - def __init__(self, trace_id: str): - self.trace_id = trace_id + trace_id: str class _Context: @@ -49,7 +46,7 @@ def reset_ids(self): class Context: - _instance = None # type: Optional[_Context] + _instance: _Context | None = None @classmethod def get(cls) -> _Context: diff --git a/packages/dsw-document-worker/dsw/document_worker/conversions.py b/packages/dsw-document-worker/dsw/document_worker/conversions.py index efe79a01..db30da50 100644 --- a/packages/dsw-document-worker/dsw/document_worker/conversions.py +++ b/packages/dsw-document-worker/dsw/document_worker/conversions.py @@ -1,10 +1,11 @@ import logging import os import pathlib -import rdflib import shlex import subprocess +import rdflib + from .config import DocumentWorkerConfig from .consts import EXIT_SUCCESS, DEFAULT_ENCODING from .documents import FileFormat, FileFormats @@ -16,14 +17,12 @@ def run_conversion(*, args: list, workdir: str, input_data: bytes, name: str, source_format: FileFormat, target_format: FileFormat, timeout=None) -> bytes: command = ' '.join(args) - LOG.info(f'Calling "{command}" to convert from {source_format} to {target_format}') - p = subprocess.Popen(args, - cwd=workdir, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - stdout, stderr = p.communicate(input=input_data, timeout=timeout) - exit_code = p.returncode + LOG.info('Calling "%s" to convert from %s to %s', + command, source_format, target_format) + with subprocess.Popen(args, cwd=workdir, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: + stdout, stderr = proc.communicate(input=input_data, timeout=timeout) + exit_code = proc.returncode if exit_code != EXIT_SUCCESS: raise FormatConversionException( name, source_format, target_format, @@ -49,7 +48,8 @@ class Pandoc: FILTERS_PATH = pathlib.Path(os.getenv('PANDOC_FILTERS', '/pandoc/filters')) TEMPLATES_PATH = pathlib.Path(os.getenv('PANDOC_TEMPLATES', '/pandoc/templates')) - def __init__(self, config: DocumentWorkerConfig, filter_names: list[str], template_name: str | None): + def __init__(self, config: DocumentWorkerConfig, filter_names: list[str], + template_name: str | None): self.config = config self.filter_names = filter_names self.template_name = template_name @@ -76,7 +76,7 @@ def _extra_args(self): args.extend(['--filter', str(self.FILTERS_PATH / filter_name)]) return shlex.split(' '.join(args)) - def __call__(self, source_format: FileFormat, target_format: FileFormat, + def __call__(self, *, source_format: FileFormat, target_format: FileFormat, data: bytes, metadata: dict, workdir: str) -> bytes: args = ['-f', source_format.name, '-t', target_format.name, '-o', '-'] config_args = shlex.split(self.config.pandoc.args) @@ -112,7 +112,7 @@ class RdfLibConvert: def __init__(self, config: DocumentWorkerConfig): self.config = config - def __call__(self, source_format: FileFormat, target_format: FileFormat, + def __call__(self, *, source_format: FileFormat, target_format: FileFormat, data: bytes, metadata: dict) -> bytes: g = rdflib.Dataset() g.parse( diff --git a/packages/dsw-document-worker/dsw/document_worker/documents.py b/packages/dsw-document-worker/dsw/document_worker/documents.py index 17e6ed47..971f36bb 100644 --- a/packages/dsw-document-worker/dsw/document_worker/documents.py +++ b/packages/dsw-document-worker/dsw/document_worker/documents.py @@ -1,8 +1,8 @@ +import typing + import pathvalidate import slugify -from typing import Optional - from dsw.database.database import DBDocument from .consts import DEFAULT_ENCODING, DocumentNamingStrategy @@ -63,8 +63,16 @@ class FileFormats: TAR_GZIP = FileFormat('gzip', 'application/gzip', 'tar.gz') TAR_BZIP2 = FileFormat('bzip2', 'application/x-bzip2', 'tar.bz2') TAR_LZMA = FileFormat('lzma', 'application/x-lzma', 'tar.xz') - XLSX = FileFormat('xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'xlsx') - XLSM = FileFormat('xlsm', ' application/vnd.ms-excel.sheet.macroEnabled.12', 'xlsm') + XLSX = FileFormat( + 'xlsx', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'xlsx', + ) + XLSM = FileFormat( + 'xlsm', + 'application/vnd.ms-excel.sheet.macroEnabled.12', + 'xlsm', + ) @staticmethod def get(name: str): @@ -107,7 +115,7 @@ def get(name: str): class DocumentFile: def __init__(self, file_format: FileFormat, content: bytes, - encoding: Optional[str] = None): + encoding: str | None = None): self.file_format = file_format self._content = content self.byte_size = len(content) @@ -164,8 +172,8 @@ def _name_slugify(document: DBDocument) -> str: class DocumentNameGiver: - _FALLBACK = _name_uuid - _STRATEGIES = { + _FALLBACK: typing.Callable[[DBDocument], str] = _name_uuid + _STRATEGIES: dict[str, typing.Callable[[DBDocument], str]] = { DocumentNamingStrategy.UUID: _name_uuid, DocumentNamingStrategy.SANITIZE: _name_sanitize, DocumentNamingStrategy.SLUGIFY: _name_slugify, diff --git a/packages/dsw-document-worker/dsw/document_worker/exceptions.py b/packages/dsw-document-worker/dsw/document_worker/exceptions.py index 2baca785..cc2ff857 100644 --- a/packages/dsw-document-worker/dsw/document_worker/exceptions.py +++ b/packages/dsw-document-worker/dsw/document_worker/exceptions.py @@ -11,8 +11,7 @@ def __str__(self): def log_message(self): if self.exc is None: return self.msg - else: - return f'{self.msg}: [{type(self.exc).__name__}] {str(self.exc)}' + return f'{self.msg}: [{type(self.exc).__name__}] {str(self.exc)}' def db_message(self): if self.exc is None: diff --git a/packages/dsw-document-worker/dsw/document_worker/handlers.py b/packages/dsw-document-worker/dsw/document_worker/handlers.py index be7f439b..714d412d 100644 --- a/packages/dsw-document-worker/dsw/document_worker/handlers.py +++ b/packages/dsw-document-worker/dsw/document_worker/handlers.py @@ -2,14 +2,15 @@ import pathlib from .cli import load_config_str -from .consts import VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH +from .consts import VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH, DEFAULT_ENCODING from .worker import DocumentWorker +# pylint: disable-next=unused-argument def lambda_handler(event, context): config_path = pathlib.Path(os.getenv(VAR_APP_CONFIG_PATH, '/var/task/application.yml')) workdir_path = pathlib.Path(os.getenv(VAR_WORKDIR_PATH, '/var/task/templates')) - config = load_config_str(config_path.read_text()) + config = load_config_str(config_path.read_text(encoding=DEFAULT_ENCODING)) doc_worker = DocumentWorker(config, workdir_path) doc_worker.run_once() diff --git a/packages/dsw-document-worker/dsw/document_worker/limits.py b/packages/dsw-document-worker/dsw/document_worker/limits.py index 17988b76..77cb5455 100644 --- a/packages/dsw-document-worker/dsw/document_worker/limits.py +++ b/packages/dsw-document-worker/dsw/document_worker/limits.py @@ -2,8 +2,6 @@ from .exceptions import JobException from .utils import byte_size_format -from typing import Optional - class LimitsEnforcer: @@ -20,7 +18,7 @@ def check_doc_size(job_id: str, doc_size: int): @staticmethod def check_size_usage(job_id: str, doc_size: int, - used_size: int, limit_size: Optional[int]): + used_size: int, limit_size: int | None): limit_size = abs(limit_size) if limit_size is not None else None if limit_size is None or doc_size + used_size < limit_size: diff --git a/packages/dsw-document-worker/dsw/document_worker/model/context.py b/packages/dsw-document-worker/dsw/document_worker/model/context.py index 06da8b53..78a3aa28 100644 --- a/packages/dsw-document-worker/dsw/document_worker/model/context.py +++ b/packages/dsw-document-worker/dsw/document_worker/model/context.py @@ -1,12 +1,13 @@ -# TODO: move to dsw-models +# pylint: disable=too-many-lines, unused-argument, too-many-arguments, +import abc import datetime -import dateutil.parser as dp +import typing -from typing import Optional, Iterable, Union, ItemsView +import dateutil.parser as dp from ..consts import NULL_UUID -AnnotationsT = dict[str, Union[str, list[str]]] +AnnotationsT = dict[str, str | list[str]] TODO_LABEL_UUID = "615b9028-5e3f-414f-b245-12d2ae2eeb20" @@ -15,12 +16,12 @@ def _datetime(timestamp: str) -> datetime.datetime: def _load_annotations(annotations: list[dict[str, str]]) -> AnnotationsT: - result = {} # type: AnnotationsT - semi_result = {} # type: dict[str, list[str]] + result: AnnotationsT = {} + semi_result: dict[str, list[str]] = {} for item in annotations: key = item.get('key', '') value = item.get('value', '') - if key in semi_result.keys(): + if key in semi_result: semi_result[key].append(value) else: semi_result[key] = [value] @@ -32,14 +33,94 @@ def _load_annotations(annotations: list[dict[str, str]]) -> AnnotationsT: return result +class SimpleAuthor: + + def __init__(self, *, uuid: str, first_name: str, last_name: str, + image_url: str | None, gravatar_hash: str | None): + self.uuid = uuid + self.first_name = first_name + self.last_name = last_name + self.image_url = image_url + self.gravatar_hash = gravatar_hash + + @staticmethod + def load(data: dict | None, **options): + if data is None: + return None + return SimpleAuthor( + uuid=data['uuid'], + first_name=data['firstName'], + last_name=data['lastName'], + image_url=data['imageUrl'], + gravatar_hash=data['gravatarHash'], + ) + + +class User: + + def __init__(self, *, uuid: str, first_name: str, last_name: str, email: str, + role: str, created_at: datetime.datetime, updated_at: datetime.datetime, + affiliation: str | None, permissions: list[str], sources: list[str], + image_url: str | None): + self.uuid = uuid + self.first_name = first_name + self.last_name = last_name + self.email = email + self.role = role + self.image_url = image_url + self.affiliation = affiliation + self.permissions = permissions + self.sources = sources + self.created_at = created_at + self.updated_at = updated_at + + @staticmethod + def load(data: dict, **options): + if data is None: + return None + return User( + uuid=data['uuid'], + first_name=data['firstName'], + last_name=data['lastName'], + email=data['email'], + role=data['role'], + image_url=data['imageUrl'], + affiliation=data['affiliation'], + permissions=data['permissions'], + sources=data['sources'], + created_at=_datetime(data['createdAt']), + updated_at=_datetime(data['updatedAt']), + ) + + +class Organization: + + def __init__(self, *, org_id: str, name: str, description: str | None, + affiliations: list[str]): + self.id = org_id + self.name = name + self.description = description + self.affiliations = affiliations + + @staticmethod + def load(data: dict, **options): + return Organization( + org_id=data['organizationId'], + name=data['name'], + description=data['description'], + affiliations=data['affiliations'], + ) + + class Tag: - def __init__(self, uuid, name, description, color, annotations): - self.uuid = uuid # type: str - self.name = name # type: str - self.description = description # type: Optional[str] - self.color = color # type: str - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, name: str, description: str | None, + color: str, annotations: AnnotationsT): + self.uuid = uuid + self.name = name + self.description = description + self.color = color + self.annotations = annotations @property def a(self): @@ -63,17 +144,18 @@ def load(data: dict, **options): class ResourceCollection: - def __init__(self, uuid, title, page_uuids, annotations): - self.uuid = uuid # type: str - self.title = title # type: str - self.page_uuids = page_uuids # type: list[str] - self.pages = list() # type: list[ResourcePage] - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, title: str, page_uuids: list[str], + annotations: AnnotationsT): + self.uuid = uuid + self.title = title + self.page_uuids = page_uuids + self.annotations = annotations + self.pages: list[ResourcePage] = [] - def _resolve_links(self, ctx): + def resolve_links(self, ctx): self.pages = [ctx.e.resource_pages[key] for key in self.page_uuids - if key in ctx.e.resource_pages.keys()] + if key in ctx.e.resource_pages] for page in self.pages: page.collection = self @@ -93,12 +175,14 @@ def load(data: dict, **options): class ResourcePage: - def __init__(self, uuid, title, content, annotations): - self.uuid = uuid # type: str - self.title = title # type: str - self.content = content # type: str - self.collection = None # type: Optional[ResourceCollection] - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, title: str, content: str, + annotations: AnnotationsT): + self.uuid = uuid + self.title = title + self.content = content + self.annotations = annotations + + self.collection: ResourceCollection | None = None @property def a(self): @@ -114,24 +198,25 @@ def load(data: dict, **options): ) -class Integration: +class Integration(abc.ABC): - def __init__(self, uuid, name, logo, integration_id, item_url, props, - integration_type, annotations): - self.uuid = uuid # type: str - self.name = name # type: str - self.id = integration_id # type: str - self.item_url = item_url # type: Optional[str] - self.logo = logo # type: Optional[str] - self.props = props # type: list[str] - self.type = integration_type # type: str - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, name: str, integration_id: str, + item_url: str | None, logo: str | None, props: list[str], + integration_type: str, annotations: AnnotationsT): + self.uuid = uuid + self.name = name + self.id = integration_id + self.item_url = item_url + self.logo = logo + self.props = props + self.type = integration_type + self.annotations = annotations @property def a(self): return self.annotations - def item(self, item_id: str) -> Optional[str]: + def item(self, item_id: str) -> str | None: if self.item_url is None: return None return self.item_url.replace('${id}', item_id) @@ -141,12 +226,20 @@ def __eq__(self, other): return False return other.uuid == self.uuid + @staticmethod + @abc.abstractmethod + def load(data: dict, **options): + pass + class ApiIntegration(Integration): - def __init__(self, uuid, name, logo, integration_id, item_url, props, rq_body, - rq_headers, rq_method, rq_url, rs_list_field, rs_item_id, - rs_item_template, annotations): + def __init__(self, *, uuid: str, name: str, integration_id: str, + item_url: str | None, logo: str | None, + props: list[str], rq_body: str, rq_method: str, + rq_headers: dict[str, str], rq_url: str, rs_list_field: str | None, + rs_item_id: str | None, rs_item_template: str, + annotations: AnnotationsT): super().__init__( uuid=uuid, name=name, @@ -157,13 +250,13 @@ def __init__(self, uuid, name, logo, integration_id, item_url, props, rq_body, annotations=annotations, integration_type='ApiIntegration', ) - self.rq_body = rq_body # type: str - self.rq_method = rq_method # type: str - self.rq_url = rq_url # type: str - self.rq_headers = rq_headers # type: dict[str, str] - self.rs_list_field = rs_list_field # type: Optional[str] - self.rs_item_id = rs_item_id # type: Optional[str] - self.rs_item_template = rs_item_template # type: str + self.rq_body = rq_body + self.rq_method = rq_method + self.rq_url = rq_url + self.rq_headers = rq_headers + self.rs_list_field = rs_list_field + self.rs_item_id = rs_item_id + self.rs_item_template = rs_item_template @staticmethod def default(): @@ -206,8 +299,9 @@ def load(data: dict, **options): class WidgetIntegration(Integration): - def __init__(self, uuid, name, logo, integration_id, item_url, props, - widget_url, annotations): + def __init__(self, *, uuid: str, name: str, integration_id: str, + item_url: str | None, logo: str | None, props: list[str], + widget_url: str, annotations: AnnotationsT): super().__init__( uuid=uuid, name=name, @@ -218,7 +312,7 @@ def __init__(self, uuid, name, logo, integration_id, item_url, props, annotations=annotations, integration_type='WidgetIntegration', ) - self.widget_url = widget_url # type: str + self.widget_url = widget_url @staticmethod def load(data: dict, **options): @@ -236,12 +330,13 @@ def load(data: dict, **options): class Phase: - def __init__(self, uuid, title, description, annotations, order=0): - self.uuid = uuid # type: str - self.title = title # type: str - self.description = description # type: Optional[str] - self.order = order # type: int - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, title: str, description: str | None, + annotations: AnnotationsT, order: int = 0): + self.uuid = uuid + self.title = title + self.description = description + self.order = order + self.annotations = annotations @property def a(self): @@ -273,12 +368,13 @@ def load(data: dict, **options): class Metric: - def __init__(self, uuid, title, description, abbreviation, annotations): - self.uuid = uuid # type: str - self.title = title # type: str - self.description = description # type: Optional[str] - self.abbreviation = abbreviation # type: str - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, title: str, description: str | None, + abbreviation: str, annotations: AnnotationsT): + self.uuid = uuid + self.title = title + self.description = description + self.abbreviation = abbreviation + self.annotations = annotations @property def a(self): @@ -302,14 +398,15 @@ def load(data: dict, **options): class MetricMeasure: - def __init__(self, measure, weight, metric_uuid): - self.measure = measure # type: float - self.weight = weight # type: float - self.metric_uuid = metric_uuid # type: str - self.metric = None # type: Optional[Metric] + def __init__(self, *, measure: float, weight: float, metric_uuid: str): + self.measure = measure + self.weight = weight + self.metric_uuid = metric_uuid + + self.metric: Metric | None = None - def _resolve_links(self, ctx): - if self.metric_uuid in ctx.e.metrics.keys(): + def resolve_links(self, ctx): + if self.metric_uuid in ctx.e.metrics: self.metric = ctx.e.metrics[self.metric_uuid] @staticmethod @@ -321,12 +418,12 @@ def load(data: dict, **options): ) -class Reference: +class Reference(abc.ABC): - def __init__(self, uuid, ref_type, annotations): - self.uuid = uuid # type: str - self.type = ref_type # type: str - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, ref_type: str, annotations: AnnotationsT): + self.uuid = uuid + self.type = ref_type + self.annotations = annotations @property def a(self): @@ -337,16 +434,26 @@ def __eq__(self, other): return False return other.uuid == self.uuid - def _resolve_links(self, ctx): + def resolve_links(self, ctx): + pass + + @staticmethod + @abc.abstractmethod + def load(data: dict, **options): pass class CrossReference(Reference): - def __init__(self, uuid, target_uuid, description, annotations): - super().__init__(uuid, 'CrossReference', annotations) - self.target_uuid = target_uuid # type: str - self.description = description # type: str + def __init__(self, *, uuid: str, target_uuid: str, description: str, + annotations: AnnotationsT): + super().__init__( + uuid=uuid, + ref_type='CrossReference', + annotations=annotations, + ) + self.target_uuid = target_uuid + self.description = description @staticmethod def load(data: dict, **options): @@ -360,10 +467,15 @@ def load(data: dict, **options): class URLReference(Reference): - def __init__(self, uuid, label, url, annotations): - super().__init__(uuid, 'URLReference', annotations) - self.label = label # type: str - self.url = url # type: str + def __init__(self, *, uuid: str, label: str, url: str, + annotations: AnnotationsT): + super().__init__( + uuid=uuid, + ref_type='URLReference', + annotations=annotations, + ) + self.label = label + self.url = url @staticmethod def load(data: dict, **options): @@ -377,13 +489,19 @@ def load(data: dict, **options): class ResourcePageReference(Reference): - def __init__(self, uuid, resource_page_uuid, annotations): - super().__init__(uuid, 'ResourcePageReference', annotations) - self.resource_page_uuid = resource_page_uuid # type: Optional[str] - self.resource_page = None # type: Optional[ResourcePage] + def __init__(self, *, uuid: str, resource_page_uuid: str | None, + annotations: AnnotationsT): + super().__init__( + uuid=uuid, + ref_type='ResourcePageReference', + annotations=annotations, + ) + self.resource_page_uuid = resource_page_uuid + + self.resource_page: ResourcePage | None = None - def _resolve_links(self, ctx): - if self.resource_page_uuid in ctx.e.resource_pages.keys(): + def resolve_links(self, ctx): + if self.resource_page_uuid in ctx.e.resource_pages: self.resource_page = ctx.e.resource_pages[self.resource_page_uuid] @staticmethod @@ -397,11 +515,12 @@ def load(data: dict, **options): class Expert: - def __init__(self, uuid, name, email, annotations): - self.uuid = uuid # type: str - self.name = name # type: str - self.email = email # type: str - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, name: str, email: str, + annotations: AnnotationsT): + self.uuid = uuid + self.name = name + self.email = email + self.annotations = annotations @property def a(self): @@ -422,40 +541,54 @@ def load(data: dict, **options): ) -class Reply: +class Reply(abc.ABC): - def __init__(self, path, created_at, created_by, reply_type): - self.path = path # type: str - self.fragments = path.split('.') # type: list[str] - self.created_at = created_at # type: datetime.datetime - self.created_by = created_by # type: Optional[SimpleAuthor] - self.type = reply_type # type: str - self.question = None # type: Optional[Question] + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, reply_type: str): + self.path = path + self.created_at = created_at + self.created_by = created_by + self.type = reply_type + + self.question: Question | None = None + self.fragments: list[str] = path.split('.') - def _resolve_links_parent(self, ctx): + def resolve_links_parent(self, ctx): question_uuid = self.fragments[-1] - if question_uuid in ctx.e.questions.keys(): + if question_uuid in ctx.e.questions: self.question = ctx.e.questions.get(question_uuid, None) if self.question is not None: self.question.replies[self.path] = self - def _resolve_links(self, ctx): + def resolve_links(self, ctx): + pass + + @staticmethod + @abc.abstractmethod + def load(path: str, data: dict, **options): pass class AnswerReply(Reply): - def __init__(self, path, created_at, created_by, answer_uuid): - super().__init__(path, created_at, created_by, 'AnswerReply') - self.answer_uuid = answer_uuid # type: str - self.answer = None # type: Optional[Answer] + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, answer_uuid: str): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='AnswerReply', + ) + self.answer_uuid = answer_uuid + + self.answer: Answer | None = None @property - def value(self) -> Optional[str]: + def value(self) -> str | None: return self.answer_uuid - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.answer = ctx.e.answers.get(self.answer_uuid, None) @staticmethod @@ -470,26 +603,32 @@ def load(path: str, data: dict, **options): class StringReply(Reply): - def __init__(self, path, created_at, created_by, value): - super().__init__(path, created_at, created_by, 'StringReply') - self.value = value # type: str + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, value: str): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='StringReply', + ) + self.value = value @property - def as_number(self) -> Optional[float]: + def as_number(self) -> float | None: try: return float(self.value) except Exception: return None @property - def as_datetime(self) -> Optional[datetime.datetime]: + def as_datetime(self) -> datetime.datetime | None: try: return dp.parse(self.value) except Exception: return None - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) @staticmethod def load(path: str, data: dict, **options): @@ -503,9 +642,15 @@ def load(path: str, data: dict, **options): class ItemListReply(Reply): - def __init__(self, path, created_at, created_by, items): - super().__init__(path, created_at, created_by, 'ItemListReply') - self.items = items # type: list[str] + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, items: list[str]): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='ItemListReply', + ) + self.items = items @property def value(self) -> list[str]: @@ -517,8 +662,8 @@ def __iter__(self): def __len__(self): return len(self.items) - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) @staticmethod def load(path: str, data: dict, **options): @@ -532,10 +677,17 @@ def load(path: str, data: dict, **options): class MultiChoiceReply(Reply): - def __init__(self, path, created_at, created_by, choice_uuids): - super().__init__(path, created_at, created_by, 'MultiChoiceReply') - self.choice_uuids = choice_uuids # type: list[str] - self.choices = list() # type: list[Choice] + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, choice_uuids: list[str]): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='MultiChoiceReply', + ) + self.choice_uuids = choice_uuids + + self.choices: list[Choice] = [] @property def value(self) -> list[str]: @@ -547,11 +699,11 @@ def __iter__(self): def __len__(self): return len(self.choices) - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.choices = [ctx.e.choices[key] for key in self.choice_uuids - if key in ctx.e.choices.keys()] + if key in ctx.e.choices] @staticmethod def load(path: str, data: dict, **options): @@ -565,13 +717,19 @@ def load(path: str, data: dict, **options): class IntegrationReply(Reply): - def __init__(self, path, created_at, created_by, item_id, value): - super().__init__(path, created_at, created_by, 'IntegrationReply') - self.item_id = item_id # type: Optional[str] - self.value = value # type: str + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, item_id: str | None, value: str): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='IntegrationReply', + ) + self.item_id = item_id + self.value = value @property - def id(self) -> Optional[str]: + def id(self) -> str | None: return self.item_id @property @@ -583,7 +741,7 @@ def is_integration(self) -> bool: return not self.is_plain @property - def url(self) -> Optional[str]: + def url(self) -> str | None: if not self.is_integration or self.item_id is None: return None if isinstance(self.question, IntegrationQuestion) \ @@ -591,8 +749,8 @@ def url(self) -> Optional[str]: return self.question.integration.item(self.item_id) return None - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) @staticmethod def load(path: str, data: dict, **options): @@ -607,17 +765,23 @@ def load(path: str, data: dict, **options): class ItemSelectReply(Reply): - def __init__(self, path, created_at, created_by, item_uuid): - super().__init__(path, created_at, created_by, 'ItemSelectReply') - self.item_uuid = item_uuid # type: str - self.item_title = 'Item' # type: str + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, item_uuid: str): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='ItemSelectReply', + ) + self.item_uuid = item_uuid + self.item_title: str = 'Item' @property def value(self) -> str: return self.item_uuid - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) @staticmethod def load(path: str, data: dict, **options): @@ -631,17 +795,23 @@ def load(path: str, data: dict, **options): class FileReply(Reply): - def __init__(self, path, created_at, created_by, file_uuid): - super().__init__(path, created_at, created_by, 'FileReply') - self.file_uuid = file_uuid # type: str - self.file = None # type: Optional[QuestionnaireFile] + def __init__(self, *, path: str, created_at: datetime.datetime, + created_by: SimpleAuthor | None, file_uuid: str): + super().__init__( + path=path, + created_at=created_at, + created_by=created_by, + reply_type='FileReply', + ) + self.file_uuid = file_uuid + self.file: QuestionnaireFile | None = None @property def value(self) -> str: return self.file_uuid - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links(ctx) self.file = ctx.questionnaire.files.get(self.file_uuid, None) if self.file is not None: self.file.reply = self @@ -658,16 +828,18 @@ def load(path: str, data: dict, **options): class Answer: - def __init__(self, uuid, label, advice, metric_measures, followup_uuids, - annotations): - self.uuid = uuid # type: str - self.label = label # type: str - self.advice = advice # type: Optional[str] - self.metric_measures = metric_measures # type: list[MetricMeasure] - self.followup_uuids = followup_uuids # type: list[str] - self.followups = list() # type: list[Question] - self.parent = None # type: Optional[OptionsQuestion] - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, label: str, advice: str | None, + metric_measures: list[MetricMeasure], followup_uuids: list[str], + annotations: AnnotationsT): + self.uuid = uuid + self.label = label + self.advice = advice + self.metric_measures = metric_measures + self.followup_uuids = followup_uuids + self.annotations = annotations + + self.followups: list[Question] = [] + self.parent: OptionsQuestion | None = None @property def a(self): @@ -678,15 +850,15 @@ def __eq__(self, other): return False return other.uuid == self.uuid - def _resolve_links(self, ctx): + def resolve_links(self, ctx): self.followups = [ctx.e.questions[key] for key in self.followup_uuids - if key in ctx.e.questions.keys()] + if key in ctx.e.questions] for followup in self.followups: followup.parent = self - followup._resolve_links(ctx) + followup.resolve_links(ctx) for mm in self.metric_measures: - mm._resolve_links(ctx) + mm.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -704,11 +876,12 @@ def load(data: dict, **options): class Choice: - def __init__(self, uuid, label, annotations): - self.uuid = uuid # type: str - self.label = label # type: str - self.parent = None # type: Optional[MultiChoiceQuestion] - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, label: str, annotations: AnnotationsT): + self.uuid = uuid + self.label = label + self.annotations = annotations + + self.parent: MultiChoiceQuestion | None = None @property def a(self): @@ -728,26 +901,29 @@ def load(data: dict, **options): ) -class Question: - - def __init__(self, uuid, q_type, title, text, tag_uuids, reference_uuids, - expert_uuids, required_phase_uuid, annotations): - self.uuid = uuid # type: str - self.type = q_type # type: str - self.title = title # type: str - self.text = text # type: Optional[str] - self.tag_uuids = tag_uuids # type: list[str] - self.tags = list() # type: list[Tag] - self.reference_uuids = reference_uuids # type: list[str] - self.references = list() # type: list[Reference] - self.expert_uuids = expert_uuids # type: list[str] - self.experts = list() # type: list[Expert] - self.required_phase_uuid = required_phase_uuid # type: Optional[str] - self.required_phase = PHASE_NEVER # type: Phase - self.replies = dict() # type: dict[str, Reply] # added from replies - self.is_required = None # type: Optional[bool] - self.parent = None # type: Optional[Union[Chapter, ListQuestion, Answer]] - self.annotations = annotations # type: AnnotationsT +class Question(abc.ABC): + + def __init__(self, *, uuid: str, q_type: str, title: str, text: str | None, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + annotations: AnnotationsT): + self.uuid = uuid + self.type = q_type + self.title = title + self.text = text + self.tag_uuids = tag_uuids + self.reference_uuids = reference_uuids + self.expert_uuids = expert_uuids + self.required_phase_uuid = required_phase_uuid + self.annotations = annotations + + self.is_required: bool | None = None + self.parent: Chapter | ListQuestion | Answer | None = None + self.replies: dict[str, Reply] = {} + self.tags: list[Tag] = [] + self.references: list[Reference] = [] + self.experts: list[Expert] = [] + self.required_phase: Phase = PHASE_NEVER @property def a(self): @@ -758,25 +934,25 @@ def __eq__(self, other): return False return other.uuid == self.uuid - def _resolve_links_parent(self, ctx): + def resolve_links_parent(self, ctx): self.tags = [ctx.e.tags[key] for key in self.tag_uuids - if key in ctx.e.tags.keys()] + if key in ctx.e.tags] self.experts = [ctx.e.experts[key] for key in self.expert_uuids - if key in ctx.e.experts.keys()] + if key in ctx.e.experts] self.references = [ctx.e.references[key] for key in self.reference_uuids - if key in ctx.e.references.keys()] + if key in ctx.e.references] for ref in self.references: - ref._resolve_links(ctx) + ref.resolve_links(ctx) if self.required_phase_uuid is None or ctx.current_phase is None: self.is_required = False else: self.required_phase = ctx.e.phases.get(self.required_phase_uuid, PHASE_NEVER) self.is_required = ctx.current_phase.order >= self.required_phase.order - def _resolve_links(self, ctx): + def resolve_links(self, ctx): pass @property @@ -791,6 +967,11 @@ def resource_page_references(self) -> list[ResourcePageReference]: def cross_references(self) -> list[CrossReference]: return [r for r in self.references if isinstance(r, CrossReference)] + @staticmethod + @abc.abstractmethod + def load(data: dict, **options): + pass + class ValueQuestionValidation: SHORT_TYPE: dict[str, str] = { @@ -826,7 +1007,7 @@ class ValueQuestionValidation: 'DomainQuestionValidation': str, } - def __init__(self, validation_type: str, value: str | int | float | None = None): + def __init__(self, *, validation_type: str, value: str | int | float | None = None): self.type = self.SHORT_TYPE.get(validation_type, 'unknown') self.full_type = validation_type self.value = value @@ -841,13 +1022,23 @@ def load(data: dict, **options): class ValueQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, - expert_uuids, required_phase_uuid, value_type, annotations): - super().__init__(uuid, 'ValueQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) - self.value_type = value_type # type: str - self.validations = list() # type: list[ValueQuestionValidation] + def __init__(self, *, uuid: str, title: str, text: str | None, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + value_type: str, annotations: AnnotationsT): + super().__init__( + uuid=uuid, + q_type='ValueQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) + self.value_type = value_type + self.validations: list[ValueQuestionValidation] = [] @property def a(self): @@ -889,8 +1080,8 @@ def is_datetime(self): def is_date(self): return self.value_type == 'DateQuestionValueType' - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) @staticmethod def load(data: dict, **options): @@ -912,23 +1103,33 @@ def load(data: dict, **options): class OptionsQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, - expert_uuids, required_phase_uuid, answer_uuids, - annotations): - super().__init__(uuid, 'OptionsQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) - self.answer_uuids = answer_uuids # type: list[str] - self.answers = list() # type: list[Answer] - - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def __init__(self, *, uuid: str, title: str, text: str | None, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + answer_uuids: list[str], annotations: AnnotationsT): + super().__init__( + uuid=uuid, + q_type='OptionsQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) + self.answer_uuids = answer_uuids + + self.answers: list[Answer] = [] + + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.answers = [ctx.e.answers[key] for key in self.answer_uuids - if key in ctx.e.answers.keys()] + if key in ctx.e.answers] for answer in self.answers: answer.parent = self - answer._resolve_links(ctx) + answer.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -947,20 +1148,30 @@ def load(data: dict, **options): class MultiChoiceQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, - expert_uuids, required_phase_uuid, choice_uuids, - annotations): - super().__init__(uuid, 'MultiChoiceQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) - self.choice_uuids = choice_uuids # type: list[str] - self.choices = list() # type: list[Choice] - - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def __init__(self, *, uuid: str, title: str, text: str | None, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + choice_uuids: list[str], annotations: AnnotationsT): + super().__init__( + uuid=uuid, + q_type='MultiChoiceQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) + self.choice_uuids = choice_uuids + + self.choices: list[Choice] = [] + + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.choices = [ctx.e.choices[key] for key in self.choice_uuids - if key in ctx.e.choices.keys()] + if key in ctx.e.choices] for choice in self.choices: choice.parent = self @@ -981,23 +1192,33 @@ def load(data: dict, **options): class ListQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, - expert_uuids, required_phase_uuid, followup_uuids, - annotations): - super().__init__(uuid, 'ListQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) - self.followup_uuids = followup_uuids # type: list[str] - self.followups = list() # type: list[Question] - - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def __init__(self, *, uuid: str, title: str, text: str, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + followup_uuids: list[str], annotations: AnnotationsT): + super().__init__( + uuid=uuid, + q_type='ListQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) + self.followup_uuids = followup_uuids + + self.followups: list[Question] = [] + + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.followups = [ctx.e.questions[key] for key in self.followup_uuids - if key in ctx.e.questions.keys()] + if key in ctx.e.questions] for followup in self.followups: followup.parent = self - followup._resolve_links(ctx) + followup.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -1016,18 +1237,29 @@ def load(data: dict, **options): class IntegrationQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, props, - expert_uuids, required_phase_uuid, integration_uuid, - annotations): - super().__init__(uuid, 'IntegrationQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) - self.props = props # type: dict[str, str] - self.integration_uuid = integration_uuid # type: str - self.integration = None # type: Optional[Integration] - - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def __init__(self, *, uuid: str, title: str, text: str | None, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + integration_uuid: str | None, props: dict[str, str], + annotations: AnnotationsT): + super().__init__( + uuid=uuid, + q_type='IntegrationQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) + self.props = props + self.integration_uuid = integration_uuid + + self.integration: Integration | None = None + + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.integration = ctx.e.integrations.get( self.integration_uuid, ApiIntegration.default(), @@ -1051,17 +1283,26 @@ def load(data: dict, **options): class ItemSelectQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, - expert_uuids, required_phase_uuid, list_question_uuid, - annotations): - super().__init__(uuid, 'ItemSelectQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) - self.list_question_uuid = list_question_uuid # type: str - self.list_question = None # type: Optional[ListQuestion] - - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) + def __init__(self, *, uuid: str, title: str, text: str | None, + tag_uuids: list[str], reference_uuids: list[str], + expert_uuids: list[str], required_phase_uuid: str | None, + list_question_uuid: str | None, annotations: AnnotationsT): + super().__init__( + uuid=uuid, + q_type='ItemSelectQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) + self.list_question_uuid = list_question_uuid + self.list_question = None + + def resolve_links(self, ctx): + super().resolve_links_parent(ctx) self.list_question = ctx.e.questions.get(self.list_question_uuid, None) @staticmethod @@ -1081,18 +1322,23 @@ def load(data: dict, **options): class FileQuestion(Question): - def __init__(self, uuid, title, text, tag_uuids, reference_uuids, + def __init__(self, *, uuid, title, text, tag_uuids, reference_uuids, expert_uuids, required_phase_uuid, max_size, file_types, annotations): - super().__init__(uuid, 'FileQuestion', title, text, tag_uuids, - reference_uuids, expert_uuids, required_phase_uuid, - annotations) + super().__init__( + uuid=uuid, + q_type='FileQuestion', + title=title, + text=text, + tag_uuids=tag_uuids, + reference_uuids=reference_uuids, + expert_uuids=expert_uuids, + required_phase_uuid=required_phase_uuid, + annotations=annotations, + ) self.max_size = max_size self.file_types = file_types - def _resolve_links(self, ctx): - super()._resolve_links_parent(ctx) - @staticmethod def load(data: dict, **options): return FileQuestion( @@ -1111,14 +1357,16 @@ def load(data: dict, **options): class Chapter: - def __init__(self, uuid, title, text, question_uuids, annotations): - self.uuid = uuid # type: str - self.title = title # type: str - self.text = text # type: Optional[str] - self.question_uuids = question_uuids # type: list[str] - self.questions = list() # type: list[Question] - self.reports = list() # type: list[ReportItem] - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, title: str, text: str | None, + question_uuids: list[str], annotations: AnnotationsT): + self.uuid = uuid + self.title = title + self.text = text + self.question_uuids = question_uuids + self.annotations = annotations + + self.questions: list[Question] = [] + self.reports: list[ReportItem] = [] @property def a(self): @@ -1129,13 +1377,13 @@ def __eq__(self, other): return False return other.uuid == self.uuid - def _resolve_links(self, ctx): + def resolve_links(self, ctx): self.questions = [ctx.e.questions[key] for key in self.question_uuids - if key in ctx.e.questions.keys()] + if key in ctx.e.questions] for question in self.questions: question.parent = self - question._resolve_links(ctx) + question.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -1148,75 +1396,88 @@ def load(data: dict, **options): ) +_QUESTION_TYPES: dict[str, type[Question]] = { + 'OptionsQuestion': OptionsQuestion, + 'ListQuestion': ListQuestion, + 'ValueQuestion': ValueQuestion, + 'MultiChoiceQuestion': MultiChoiceQuestion, + 'IntegrationQuestion': IntegrationQuestion, + 'ItemSelectQuestion': ItemSelectQuestion, + 'FileQuestion': FileQuestion, +} + + +_REFERENCE_TYPES: dict[str, type[Reference]] = { + 'URLReference': URLReference, + 'ResourcePageReference': ResourcePageReference, + 'CrossReference': CrossReference, +} + + +_INTEGRATION_TYPES: dict[str, type[Integration]] = { + 'ApiIntegration': ApiIntegration, + 'WidgetIntegration': WidgetIntegration, +} + + +_REPLY_TYPES: dict[str, type[Reply]] = { + 'AnswerReply': AnswerReply, + 'StringReply': StringReply, + 'ItemListReply': ItemListReply, + 'MultiChoiceReply': MultiChoiceReply, + 'IntegrationReply': IntegrationReply, + 'ItemSelectReply': ItemSelectReply, + 'FileReply': FileReply, +} + + def _load_question(data: dict, **options): - if data['questionType'] == 'OptionsQuestion': - return OptionsQuestion.load(data, **options) - if data['questionType'] == 'ListQuestion': - return ListQuestion.load(data, **options) - if data['questionType'] == 'ValueQuestion': - return ValueQuestion.load(data, **options) - if data['questionType'] == 'MultiChoiceQuestion': - return MultiChoiceQuestion.load(data, **options) - if data['questionType'] == 'IntegrationQuestion': - return IntegrationQuestion.load(data, **options) - if data['questionType'] == 'ItemSelectQuestion': - return ItemSelectQuestion.load(data, **options) - if data['questionType'] == 'FileQuestion': - return FileQuestion.load(data, **options) - raise ValueError(f'Unknown question type: {data["questionType"]}') + question_type = data['questionType'] + question_class = _QUESTION_TYPES.get(question_type, None) + if question_class is None: + raise ValueError(f'Unknown question type: {question_type}') + return question_class.load(data, **options) def _load_reference(data: dict, **options): - if data['referenceType'] == 'URLReference': - return URLReference.load(data, **options) - if data['referenceType'] == 'ResourcePageReference': - return ResourcePageReference.load(data, **options) - if data['referenceType'] == 'CrossReference': - return CrossReference.load(data, **options) - raise ValueError(f'Unknown reference type: {data["referenceType"]}') + reference_type = data['referenceType'] + reference_class = _REFERENCE_TYPES.get(reference_type, None) + if reference_class is None: + raise ValueError(f'Unknown reference type: {reference_type}') + return reference_class.load(data, **options) def _load_integration(data: dict, **options): - if data['integrationType'] == 'ApiIntegration': - return ApiIntegration.load(data, **options) - if data['integrationType'] == 'WidgetIntegration': - return WidgetIntegration.load(data, **options) - raise ValueError(f'Unknown integration type: {data["integrationType"]}') + integration_type = data['integrationType'] + integration_class = _INTEGRATION_TYPES.get(integration_type, None) + if integration_class is None: + raise ValueError(f'Unknown integration type: {integration_type}') + return integration_class.load(data, **options) def _load_reply(path: str, data: dict, **options): - if data['value']['type'] == 'AnswerReply': - return AnswerReply.load(path, data, **options) - if data['value']['type'] == 'StringReply': - return StringReply.load(path, data, **options) - if data['value']['type'] == 'ItemListReply': - return ItemListReply.load(path, data, **options) - if data['value']['type'] == 'MultiChoiceReply': - return MultiChoiceReply.load(path, data, **options) - if data['value']['type'] == 'IntegrationReply': - return IntegrationReply.load(path, data, **options) - if data['value']['type'] == 'ItemSelectReply': - return ItemSelectReply.load(path, data, **options) - if data['value']['type'] == 'FileReply': - return FileReply.load(path, data, **options) - raise ValueError(f'Unknown reply type: {data["value"]["type"]}') + reply_type = data['value']['type'] + reply_class = _REPLY_TYPES.get(reply_type, None) + if reply_class is None: + raise ValueError(f'Unknown reply type: {reply_type}') + return reply_class.load(path, data, **options) class KnowledgeModelEntities: def __init__(self): - self.chapters = dict() # type: dict[str, Chapter] - self.questions = dict() # type: dict[str, Question] - self.answers = dict() # type: dict[str, Answer] - self.choices = dict() # type: dict[str, Choice] - self.resource_collections = dict() # type: dict[str, ResourceCollection] - self.resource_pages = dict() # type: dict[str, ResourcePage] - self.references = dict() # type: dict[str, Reference] - self.experts = dict() # type: dict[str, Expert] - self.tags = dict() # type: dict[str, Tag] - self.metrics = dict() # type: dict[str, Metric] - self.phases = dict() # type: dict[str, Phase] - self.integrations = dict() # type: dict[str, Integration] + self.chapters: dict[str, Chapter] = {} + self.questions: dict[str, Question] = {} + self.answers: dict[str, Answer] = {} + self.choices: dict[str, Choice] = {} + self.resource_collections: dict[str, ResourceCollection] = {} + self.resource_pages: dict[str, ResourcePage] = {} + self.references: dict[str, Reference] = {} + self.experts: dict[str, Expert] = {} + self.tags: dict[str, Tag] = {} + self.metrics: dict[str, Metric] = {} + self.phases: dict[str, Phase] = {} + self.integrations: dict[str, Integration] = {} @staticmethod def load(data: dict, **options): @@ -1250,24 +1511,26 @@ def load(data: dict, **options): class KnowledgeModel: - def __init__(self, uuid, chapter_uuids, tag_uuids, metric_uuids, - phase_uuids, integration_uuids, resource_collection_uuids, - entities, annotations): - self.uuid = uuid # type: str - self.entities = entities # type: KnowledgeModelEntities - self.chapter_uuids = chapter_uuids # type: list[str] - self.chapters = list() # type: list[Chapter] - self.tag_uuids = tag_uuids # type: list[str] - self.tags = list() # type: list[Tag] - self.metric_uuids = metric_uuids # type: list[str] - self.metrics = list() # type: list[Metric] - self.phase_uuids = phase_uuids # type: list[str] - self.phases = list() # type: list[Phase] - self.resource_collection_uuids = resource_collection_uuids # type: list[str] - self.resource_collections = list() # type: list[ResourceCollection] - self.integration_uuids = integration_uuids # type: list[str] - self.integrations = list() # type: list[Integration] - self.annotations = annotations # type: AnnotationsT + def __init__(self, *, uuid: str, chapter_uuids: list[str], tag_uuids: list[str], + metric_uuids: list[str], phase_uuids: list[str], integration_uuids: list[str], + resource_collection_uuids: list[str], entities: KnowledgeModelEntities, + annotations: AnnotationsT): + self.uuid = uuid + self.entities = entities + self.chapter_uuids = chapter_uuids + self.tag_uuids = tag_uuids + self.metric_uuids = metric_uuids + self.phase_uuids = phase_uuids + self.resource_collection_uuids = resource_collection_uuids + self.integration_uuids = integration_uuids + self.annotations = annotations + + self.chapters: list[Chapter] = [] + self.tags: list[Tag] = [] + self.metrics: list[Metric] = [] + self.phases: list[Phase] = [] + self.resource_collections: list[ResourceCollection] = [] + self.integrations: list[Integration] = [] @property def a(self): @@ -1277,31 +1540,31 @@ def a(self): def e(self): return self.entities - def _resolve_links(self, ctx): + def resolve_links(self, ctx): self.chapters = [ctx.e.chapters[key] for key in self.chapter_uuids - if key in ctx.e.chapters.keys()] + if key in ctx.e.chapters] self.tags = [ctx.e.tags[key] for key in self.tag_uuids - if key in ctx.e.tags.keys()] + if key in ctx.e.tags] self.metrics = [ctx.e.metrics[key] for key in self.metric_uuids - if key in ctx.e.metrics.keys()] + if key in ctx.e.metrics] self.phases = [ctx.e.phases[key] for key in self.phase_uuids - if key in ctx.e.phases.keys()] + if key in ctx.e.phases] self.resource_collections = [ctx.e.resource_collections[key] for key in self.resource_collection_uuids - if key in ctx.e.resource_collections.keys()] + if key in ctx.e.resource_collections] self.integrations = [ctx.e.integrations[key] for key in self.integration_uuids - if key in ctx.e.integrations.keys()] + if key in ctx.e.integrations] for index, phase in enumerate(self.phases, start=1): phase.order = index for chapter in self.chapters: - chapter._resolve_links(ctx) + chapter.resolve_links(ctx) for resource_collection in self.resource_collections: - resource_collection._resolve_links(ctx) + resource_collection.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -1320,8 +1583,8 @@ def load(data: dict, **options): class ContextConfig: - def __init__(self, client_url): - self.client_url = client_url # type: str + def __init__(self, *, client_url: str | None): + self.client_url = client_url @staticmethod def load(data: dict, **options): @@ -1332,14 +1595,14 @@ def load(data: dict, **options): class Document: - def __init__(self, uuid, name, document_template_id, format_uuid, - created_by, created_at): - self.uuid = uuid # type: str - self.name = name # type: str - self.document_template_id = document_template_id # type: str - self.format_uuid = format_uuid # type: str - self.created_by = created_by # type: Optional[User] - self.created_at = created_at # type: datetime.datetime + def __init__(self, *, uuid: str, name: str, document_template_id: str, format_uuid: str, + created_by: User | None, created_at: datetime.datetime): + self.uuid = uuid + self.name = name + self.document_template_id = document_template_id + self.format_uuid = format_uuid + self.created_by = created_by + self.created_at = created_at @staticmethod def load(data: dict, **options): @@ -1353,39 +1616,18 @@ def load(data: dict, **options): ) -class SimpleAuthor: - - def __init__(self, uuid, first_name, last_name, image_url, gravatar_hash): - self.uuid = uuid # type: str - self.first_name = first_name # type: str - self.last_name = last_name # type: str - self.image_url = image_url # type: Optional[str] - self.gravatar_hash = gravatar_hash # type: Optional[str] - - @staticmethod - def load(data: Optional[dict], **options): - if data is None: - return None - return SimpleAuthor( - uuid=data['uuid'], - first_name=data['firstName'], - last_name=data['lastName'], - image_url=data['imageUrl'], - gravatar_hash=data['gravatarHash'], - ) - - class QuestionnaireVersion: - def __init__(self, uuid, event_uuid, name, description, - created_at, updated_at, created_by): - self.uuid = uuid # type: str - self.event_uuid = event_uuid # type: str - self.name = name # type: str - self.description = description # type: str - self.created_at = created_at # type: datetime.datetime - self.updated_at = updated_at # type: datetime.datetime - self.created_by = created_by # type: Optional[SimpleAuthor] + def __init__(self, *, uuid: str, event_uuid: str, name: str, description: str | None, + created_at: datetime.datetime, updated_at: datetime.datetime, + created_by: SimpleAuthor | None): + self.uuid = uuid + self.event_uuid = event_uuid + self.name = name + self.description = description + self.created_at = created_at + self.updated_at = updated_at + self.created_by = created_by @staticmethod def load(data: dict, **options): @@ -1402,53 +1644,53 @@ def load(data: dict, **options): class RepliesContainer: - def __init__(self, replies: dict[str, Reply]): + def __init__(self, *, replies: dict[str, Reply]): self.replies = replies - def __getitem__(self, path: str) -> Optional[Reply]: + def __getitem__(self, path: str) -> Reply | None: return self.get(path) def __len__(self) -> int: return len(self.replies) - def get(self, path: str, default=None) -> Optional[Reply]: + def get(self, path: str, default=None) -> Reply | None: return self.replies.get(path, default) - def iterate_by_prefix(self, path_prefix: str) -> Iterable[Reply]: + def iterate_by_prefix(self, path_prefix: str) -> typing.Iterable[Reply]: return (r for path, r in self.replies.items() if path.startswith(path_prefix)) - def iterate_by_suffix(self, path_suffix: str) -> Iterable[Reply]: + def iterate_by_suffix(self, path_suffix: str) -> typing.Iterable[Reply]: return (r for path, r in self.replies.items() if path.endswith(path_suffix)) - def values(self) -> Iterable[Reply]: + def values(self) -> typing.Iterable[Reply]: return self.replies.values() - def keys(self) -> Iterable[str]: - return self.replies.keys() + def keys(self) -> typing.Iterable[str]: + return self.replies - def items(self) -> ItemsView[str, Reply]: + def items(self) -> typing.ItemsView[str, Reply]: return self.replies.items() class QuestionnaireFile: - def __init__(self, uuid, file_name, file_size, content_type): - self.uuid = uuid # type: str - self.name = file_name # type: str - self.size = file_size # type: int - self.content_type = content_type # type: str + def __init__(self, *, uuid: str, file_name: str, file_size: int, + content_type: str): + self.uuid = uuid + self.name = file_name + self.size = file_size + self.content_type = content_type - self.reply = None # type: Optional[FileReply] - self.download_url = '' # type: str + self.reply: FileReply | None = None + self.download_url: str = '' - def _resolve_links(self, ctx): - # Update the download URL: /projects//files//download + def resolve_links(self, ctx): project_uuid = ctx.questionnaire.uuid client_url = ctx.config.client_url self.download_url = f'{client_url}/projects/{project_uuid}/files/{self.uuid}/download' @staticmethod - def load(data: dict, questionnaire_uuid: str, **options): + def load(data: dict, **options): return QuestionnaireFile( uuid=data['uuid'], file_name=data['fileName'], @@ -1459,28 +1701,31 @@ def load(data: dict, questionnaire_uuid: str, **options): class Questionnaire: - def __init__(self, uuid, name, description, created_by, phase_uuid, - created_at, updated_at): - self.uuid = uuid # type: str - self.name = name # type: str - self.description = description # type: str - self.version = None # type: Optional[QuestionnaireVersion] - self.versions = list() # type: list[QuestionnaireVersion] - self.files = dict() # type: dict[str, QuestionnaireFile] - self.todos = list() # type: list[str] - self.created_by = created_by # type: User - self.phase_uuid = phase_uuid # type: Optional[str] - self.phase = PHASE_NEVER # type: Phase - self.project_tags = list() # type: list[str] - self.replies = RepliesContainer(dict()) # type: RepliesContainer + def __init__(self, *, uuid: str, name: str, description: str | None, + created_by: User, phase_uuid: str | None, + created_at: datetime.datetime, updated_at: datetime.datetime): + self.uuid = uuid + self.name = name + self.description = description + self.created_by = created_by + self.phase_uuid = phase_uuid self.created_at = created_at self.updated_at = updated_at - def _resolve_links(self, ctx): + self.version: QuestionnaireVersion | None = None + self.versions: list[QuestionnaireVersion] = [] + self.files: dict[str, QuestionnaireFile] = {} + self.todos: list[str] = [] + self.project_tags: list[str] = [] + self.phase: Phase = PHASE_NEVER + + self.replies: RepliesContainer = RepliesContainer(replies={}) + + def resolve_links(self, ctx): for reply in self.replies.values(): - reply._resolve_links(ctx) - for file in self.files.values(): - file._resolve_links(ctx) + reply.resolve_links(ctx) + for questionnaire_file in self.files.values(): + questionnaire_file.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -1490,7 +1735,7 @@ def load(data: dict, **options): version = None replies = {p: _load_reply(p, d, **options) for p, d in data['replies'].items()} - files = {d['uuid']: QuestionnaireFile.load(d, questionnaire_uuid, **options) + files = {d['uuid']: QuestionnaireFile.load(d, **options) for d in data.get('files', [])} for v in versions: if v.uuid == data['versionUuid']: @@ -1515,15 +1760,17 @@ def load(data: dict, **options): class Package: - def __init__(self, org_id, km_id, version, versions, name, description, created_at): - self.organization_id = org_id # type: str - self.km_id = km_id # type: str - self.version = version # type: str - self.id = f'{org_id}:{km_id}:{version}' - self.versions = versions # type: list[str] - self.name = name # type: str - self.description = description # type: str - self.created_at = created_at # type: datetime.datetime + def __init__(self, *, org_id: str, km_id: str, version: str, versions: list[str], + name: str, description: str, created_at: datetime.datetime): + self.organization_id = org_id + self.km_id = km_id + self.version = version + self.versions = versions + self.name = name + self.description = description + self.created_at = created_at + + self.id: str = f'{org_id}:{km_id}:{version}' @property def org_id(self): @@ -1544,10 +1791,10 @@ def load(data: dict, **options): class ReportIndication: - def __init__(self, indication_type, answered, unanswered): - self.indication_type = indication_type # type: str - self.answered = answered # type: int - self.unanswered = unanswered # type: int + def __init__(self, *, indication_type: str, answered: int, unanswered: int): + self.indication_type = indication_type + self.answered = answered + self.unanswered = unanswered @property def total(self) -> int: @@ -1578,13 +1825,14 @@ def load(data: dict, **options): class ReportMetric: - def __init__(self, measure, metric_uuid): - self.measure = measure # type: float - self.metric_uuid = metric_uuid # type: str - self.metric = None # type: Optional[Metric] + def __init__(self, *, measure: float, metric_uuid: str): + self.measure = measure + self.metric_uuid = metric_uuid + + self.metric: Metric | None = None - def _resolve_links(self, ctx): - if self.metric_uuid in ctx.e.metrics.keys(): + def resolve_links(self, ctx): + if self.metric_uuid in ctx.e.metrics: self.metric = ctx.e.metrics[self.metric_uuid] @staticmethod @@ -1597,16 +1845,18 @@ def load(data: dict, **options): class ReportItem: - def __init__(self, indications, metrics, chapter_uuid): - self.indications = indications # type: list[ReportIndication] - self.metrics = metrics # type: list[ReportMetric] - self.chapter_uuid = chapter_uuid # type: Optional[str] - self.chapter = None # type: Optional[Chapter] + def __init__(self, *, indications: list[ReportIndication], metrics: list[ReportMetric], + chapter_uuid: str | None): + self.indications = indications + self.metrics = metrics + self.chapter_uuid = chapter_uuid - def _resolve_links(self, ctx): + self.chapter: Chapter | None = None + + def resolve_links(self, ctx): for m in self.metrics: - m._resolve_links(ctx) - if self.chapter_uuid is not None and self.chapter_uuid in ctx.e.chapters.keys(): + m.resolve_links(ctx) + if self.chapter_uuid is not None and self.chapter_uuid in ctx.e.chapters: self.chapter = ctx.e.chapters[self.chapter_uuid] if self.chapter is not None: self.chapter.reports.append(self) @@ -1624,17 +1874,19 @@ def load(data: dict, **options): class Report: - def __init__(self, uuid, created_at, updated_at, chapter_reports, total_report): - self.uuid = uuid # type: str - self.created_at = created_at # type: datetime.datetime - self.updated_at = updated_at # type: datetime.datetime - self.total_report = total_report # type: ReportItem - self.chapter_reports = chapter_reports # type: list[ReportItem] + def __init__(self, *, uuid: str, created_at: datetime.datetime, + updated_at: datetime.datetime, chapter_reports: list[ReportItem], + total_report: ReportItem): + self.uuid = uuid + self.created_at = created_at + self.updated_at = updated_at + self.total_report = total_report + self.chapter_reports = chapter_reports - def _resolve_links(self, ctx): - self.total_report._resolve_links(ctx) + def resolve_links(self, ctx): + self.total_report.resolve_links(ctx) for report in self.chapter_reports: - report._resolve_links(ctx) + report.resolve_links(ctx) @staticmethod def load(data: dict, **options): @@ -1648,70 +1900,18 @@ def load(data: dict, **options): ) -class User: - - def __init__(self, uuid, first_name, last_name, email, role, created_at, - updated_at, affiliation, permissions, sources, image_url): - self.uuid = uuid # type: str - self.first_name = first_name # type: str - self.last_name = last_name # type: str - self.email = email # type: str - self.role = role # type: str - self.image_url = image_url # type: Optional[str] - self.affiliation = affiliation # type: Optional[str] - self.permissions = permissions # type: list[str] - self.sources = sources # type: list[str] - self.created_at = created_at # type: datetime.datetime - self.updated_at = updated_at # type: datetime.datetime - - @staticmethod - def load(data: dict, **options): - if data is None: - return None - return User( - uuid=data['uuid'], - first_name=data['firstName'], - last_name=data['lastName'], - email=data['email'], - role=data['role'], - image_url=data['imageUrl'], - affiliation=data['affiliation'], - permissions=data['permissions'], - sources=data['sources'], - created_at=_datetime(data['createdAt']), - updated_at=_datetime(data['updatedAt']), - ) - - -class Organization: - - def __init__(self, org_id, name, description, affiliations): - self.id = org_id # type: str - self.name = name # type: str - self.description = description # type: Optional[str] - self.affiliations = affiliations # type: list[str] - - @staticmethod - def load(data: dict, **options): - return Organization( - org_id=data['organizationId'], - name=data['name'], - description=data['description'], - affiliations=data['affiliations'], - ) - - class UserGroup: - def __init__(self, uuid, name, description, private, - created_at, updated_at): - self.uuid = uuid # type: str - self.name = name # type: str - self.description = description # type: Optional[str] - self.private = private # type: bool - self.members = list() # type: list[UserGroupMember] - self.created_at = created_at # type: datetime.datetime - self.updated_at = updated_at # type: datetime.datetime + def __init__(self, *, uuid: str, name: str, description: str | None, private: bool, + created_at: datetime.datetime, updated_at: datetime.datetime): + self.uuid = uuid + self.name = name + self.description = description + self.private = private + self.created_at = created_at + self.updated_at = updated_at + + self.members: list[UserGroupMember] = [] @staticmethod def load(data: dict, **options): @@ -1730,14 +1930,14 @@ def load(data: dict, **options): class UserGroupMember: - def __init__(self, uuid, first_name, last_name, gravatar_hash, - image_url, membership_type): - self.uuid = uuid # type: str - self.first_name = first_name # type: str - self.last_name = last_name # type: str - self.gravatar_hash = gravatar_hash # type: str - self.image_url = image_url # type: Optional[str] - self.membership_type = membership_type # type: str + def __init__(self, *, uuid: str, first_name: str, last_name: str, gravatar_hash: str, + image_url: str | None, membership_type: str): + self.uuid = uuid + self.first_name = first_name + self.last_name = last_name + self.gravatar_hash = gravatar_hash + self.image_url = image_url + self.membership_type = membership_type @staticmethod def load(data: dict, **options): @@ -1756,9 +1956,9 @@ def load(data: dict, **options): class DocumentContextUserPermission: - def __init__(self, user, permissions): - self.user = user # type: Optional[User] - self.permissions = permissions # type: list[str] + def __init__(self, *, user: User | None, permissions: list[str]): + self.user = user + self.permissions = permissions @property def is_viewer(self): @@ -1786,9 +1986,9 @@ def load(data: dict, **options): class DocumentContextUserGroupPermission: - def __init__(self, group, permissions): - self.group = group # type: Optional[UserGroup] - self.permissions = permissions # type: list[str] + def __init__(self, *, group: UserGroup | None, permissions: list[str]): + self.group = group + self.permissions = permissions @property def is_viewer(self): @@ -1818,7 +2018,7 @@ class DocumentContext: """Document Context smart representation""" METAMODEL_VERSION = 16 - def __init__(self, ctx, **options): + def __init__(self, *, ctx, **options): self.metamodel_version = int(ctx.get('metamodelVersion', '0')) if self.metamodel_version != self.METAMODEL_VERSION: raise ValueError(f'Unsupported metamodel version: {self.metamodel_version} ' @@ -1831,7 +2031,7 @@ def __init__(self, ctx, **options): self.document = Document.load(ctx['document'], **options) self.package = Package.load(ctx['package'], **options) self.organization = Organization.load(ctx['organization'], **options) - self.current_phase = PHASE_NEVER # type: Phase + self.current_phase: Phase = PHASE_NEVER self.users = [DocumentContextUserPermission.load(d, **options) for d in ctx['users']] @@ -1866,16 +2066,16 @@ def doc(self) -> Document: def replies(self) -> RepliesContainer: return self.questionnaire.replies - def _resolve_links(self): + def resolve_links(self): phase_uuid = self.questionnaire.phase_uuid - if phase_uuid is not None and phase_uuid in self.e.phases.keys(): + if phase_uuid is not None and phase_uuid in self.e.phases: self.current_phase = self.e.phases[phase_uuid] self.questionnaire.phase = self.current_phase - self.km._resolve_links(self) - self.report._resolve_links(self) - self.questionnaire._resolve_links(self) + self.km.resolve_links(self) + self.report.resolve_links(self) + self.questionnaire.resolve_links(self) - rv = ReplyVisitor(self) + rv = ReplyVisitor(context=self) rv.visit() for reply in self.replies.values(): if isinstance(reply, ItemSelectReply): @@ -1884,9 +2084,9 @@ def _resolve_links(self): class ReplyVisitor: - def __init__(self, context: DocumentContext): - self.item_titles = dict() # type: dict[str, str] - self._set_also = dict() # type: dict[str, list[str]] + def __init__(self, *, context: DocumentContext): + self.item_titles: dict[str, str] = {} + self._set_also: dict[str, list[str]] = {} self.context = context def visit(self): @@ -1902,7 +2102,7 @@ def _visit_question(self, question: Question, path: str): if isinstance(question, ListQuestion): self._visit_list_question(question, new_path) elif isinstance(question, OptionsQuestion): - self._visit_options_question(question, new_path) + self._visit_options_question(new_path) def _visit_list_question(self, question: ListQuestion, path: str): reply = self.context.replies.get(path) @@ -1912,30 +2112,31 @@ def _visit_list_question(self, question: ListQuestion, path: str): self.item_titles[item_uuid] = f'Item {n}' item_path = f'{path}.{item_uuid}' - # title if len(question.followups) > 0: title_path = f'{item_path}.{question.followups[0].uuid}' title_reply = self.context.replies.get(title_path) if title_reply is not None and isinstance(title_reply, StringReply): self.item_titles[item_uuid] = title_reply.value elif title_reply is not None and isinstance(title_reply, IntegrationReply): - non_empty_lines = list(filter(lambda line: len(line) > 0, title_reply.value.split('\n'))) + non_empty_lines = list(filter( + lambda line: len(line) > 0, + title_reply.value.split('\n'), + )) if len(non_empty_lines) > 0: self.item_titles[item_uuid] = non_empty_lines[0] elif title_reply is not None and isinstance(title_reply, ItemSelectReply): ref_item_uuid = title_reply.item_uuid - if ref_item_uuid in self.item_titles.keys(): + if ref_item_uuid in self.item_titles: self.item_titles[item_uuid] = self.item_titles[ref_item_uuid] else: self._set_also.setdefault(ref_item_uuid, []).append(item_uuid) for set_also in self._set_also.get(item_uuid, []): self.item_titles[set_also] = self.item_titles[item_uuid] - # followups for followup in question.followups: self._visit_question(followup, path=item_path) - def _visit_options_question(self, question: OptionsQuestion, path: str): + def _visit_options_question(self, path: str): reply = self.context.replies.get(path) if reply is None or not isinstance(reply, AnswerReply) or reply.answer is None: return diff --git a/packages/dsw-document-worker/dsw/document_worker/model/http.py b/packages/dsw-document-worker/dsw/document_worker/model/http.py index 8476d091..4e75d0c0 100644 --- a/packages/dsw-document-worker/dsw/document_worker/model/http.py +++ b/packages/dsw-document-worker/dsw/document_worker/model/http.py @@ -17,16 +17,13 @@ def _prepare_for_request(self): def get(self, url, params=None, **kwargs) -> requests.Response: self._prepare_for_request() - kwargs.update(timeout=self.timeout) - resp = requests.get(url=url, params=params, **kwargs) + resp = requests.get(url=url, params=params, timeout=self.timeout, **kwargs) return resp def post(self, url, data=None, json=None, **kwargs) -> requests.Response: self._prepare_for_request() - kwargs.update(timeout=self.timeout) - return requests.post(url=url, data=data, json=json, **kwargs) + return requests.post(url=url, data=data, json=json, timeout=self.timeout, **kwargs) def request(self, method: str, url: str, **kwargs) -> requests.Response: self._prepare_for_request() - kwargs.update(timeout=self.timeout) - return requests.request(method=method, url=url, **kwargs) + return requests.request(method=method, url=url, timeout=self.timeout, **kwargs) diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/filters.py b/packages/dsw-document-worker/dsw/document_worker/templates/filters.py index de4fb825..f6991fcc 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/filters.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/filters.py @@ -1,16 +1,18 @@ import datetime +import logging +import re +import typing + import dateutil.parser as dp import jinja2 -import logging import markupsafe import markdown -import re -from typing import Any, Union, Optional +from dsw.document_worker.utils import byte_size_format from ..exceptions import JobException from ..model import DocumentContext -from dsw.document_worker.utils import byte_size_format +from .tests import tests LOG = logging.getLogger(__name__) @@ -23,10 +25,10 @@ def extendMarkdown(self, md): class DSWMarkdownProcessor(markdown.preprocessors.Preprocessor): + LI_RE = re.compile(r'^[ ]*((\d+\.)|[*+-])[ ]+.*') def __init__(self, md): super().__init__(md) - self.LI_RE = re.compile(r'^[ ]*((\d+\.)|[*+-])[ ]+.*') def run(self, lines): prev_li = False @@ -55,12 +57,11 @@ def run(self, lines): class _JinjaEnv: def __init__(self): - self._env = None # type: Optional[jinja2.Environment] + self._env: jinja2.Environment | None = None @property def env(self) -> jinja2.Environment: if self._env is None: - from .tests import tests self._env = jinja2.Environment( loader=_base_jinja_loader, extensions=['jinja2.ext.do'], @@ -77,12 +78,12 @@ def get_template(self, template_str: str) -> jinja2.Template: _alphabet_size = len(_alphabet) _base_jinja_loader = jinja2.BaseLoader() _j2_env = _JinjaEnv() -_empty_dict = dict() # type: dict[str, Any] +_empty_dict: dict[str, typing.Any] = {} _romans = [(1000, 'M'), (900, 'CM'), (500, 'D'), (400, 'CD'), (100, 'C'), (90, 'XC'), (50, 'L'), (40, 'XL'), (10, 'X'), (9, 'IX'), (5, 'V'), (4, 'IV'), (1, 'I')] -def datetime_format(iso_timestamp: Union[None, datetime.datetime, str], fmt: str): +def datetime_format(iso_timestamp: None | datetime.datetime | str, fmt: str): if iso_timestamp is None: return '' if not isinstance(iso_timestamp, datetime.datetime): @@ -134,7 +135,7 @@ def _has_value(reply: dict) -> bool: return bool(reply) and ('value' in reply.keys()) and ('value' in reply['value'].keys()) -def _get_value(reply: dict) -> Any: +def _get_value(reply: dict) -> typing.Any: return reply['value']['value'] @@ -182,14 +183,14 @@ def reply_path(uuids: list) -> str: return '.'.join(map(str, uuids)) -def jinja2_render(template_str: str, vars=None, fail_safe=False, **kwargs): - if vars is None: - vars = _empty_dict +def jinja2_render(template_str: str, variables=None, fail_safe=False, **kwargs): + if variables is None: + variables = _empty_dict LOG.debug('Jinja2-in-Jinja2 rendering requested') try: j2_template = _j2_env.get_template(template_str) LOG.debug('Jinja2-in-Jinja2 template prepared') - result = j2_template.render(**vars, **kwargs) + result = j2_template.render(**variables, **kwargs) LOG.debug('Jinja2-in-Jinja2 result finished') return result except Exception as e: @@ -200,9 +201,9 @@ def jinja2_render(template_str: str, vars=None, fail_safe=False, **kwargs): def to_context_obj(ctx, **options) -> DocumentContext: LOG.debug('DocumentContext object requested') - result = DocumentContext(ctx, **options) + result = DocumentContext(ctx=ctx, **options) LOG.debug('DocumentContext object created') - result._resolve_links() + result.resolve_links() LOG.debug('DocumentContext object links resolved') return result diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/formats.py b/packages/dsw-document-worker/dsw/document_worker/templates/formats.py index 1962c300..4c5a98fb 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/formats.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/formats.py @@ -23,7 +23,7 @@ def __init__(self, template, metadata: dict): self._verify_metadata(metadata) self.uuid = self._trace = metadata[FormatField.UUID] self.name = metadata[FormatField.NAME] - LOG.info(f'Setting up format "{self.name}" ({self._trace})') + LOG.info('Setting up format "%s" (%s)', self.name, self._trace) self.steps = self._create_steps(metadata) if len(self.steps) < 1: self.template.raise_exc(f'Format {self.name} has no steps') diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/steps/archive.py b/packages/dsw-document-worker/dsw/document_worker/templates/steps/archive.py index 40e8a08a..740e2db4 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/steps/archive.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/steps/archive.py @@ -72,10 +72,9 @@ def _load_type(self) -> str: return self.TYPE_ZIP def _rectify_compression_level(self): - if self.compression_level < 0: - self.compression_level = 0 - if self.compression_level > 9: - self.compression_level = 9 + self.compression_level = max(self.compression_level, 0) + self.compression_level = min(self.compression_level, 9) + if self.mode == self.MODE_BZIP2 and self.compression_level == 0: self.compression_level = 1 @@ -114,7 +113,7 @@ def _create_tar(self) -> DocumentFile: tar_format = self.FORMATS_TAR[self.format] tar_file = TMP_DIR / 'result.tar' extra_opts = {} - if compression == 'gz' or compression == 'bz2': + if compression in ('gz', 'bz2'): extra_opts['compresslevel'] = self.compression_level with tarfile.open( name=str(tar_file), diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/steps/base.py b/packages/dsw-document-worker/dsw/document_worker/templates/steps/base.py index 220bcd35..4bb4fb73 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/steps/base.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/steps/base.py @@ -22,15 +22,18 @@ class Step: def __init__(self, template, options: dict[str, str]): self.template = template self.options = options - extras_str = self.options.get(self.OPTION_EXTRAS, '') # type: str - self.extras = set(extras_str.split(',')) # type: set[str] + + extras_str: str = self.options.get(self.OPTION_EXTRAS, '') + self.extras: set[str] = set(extras_str.split(',')) def requires_via_extras(self, requirement: str) -> bool: return requirement in self.extras + # pylint: disable-next=unused-argument def execute_first(self, context: dict) -> DocumentFile: return self.raise_exc('Called execute_follow on Step class') + # pylint: disable-next=unused-argument def execute_follow(self, document: DocumentFile, context: dict) -> DocumentFile: return self.raise_exc('Called execute_follow on Step class') @@ -38,10 +41,10 @@ def raise_exc(self, message: str): raise FormatStepException(message) -STEPS = dict() +STEPS: dict[str, type[Step]] = {} -def register_step(name: str, step_class: type): +def register_step(name: str, step_class: type[Step]): STEPS[name.lower()] = step_class diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/steps/conversion.py b/packages/dsw-document-worker/dsw/document_worker/templates/steps/conversion.py index c318b67a..712b3f29 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/steps/conversion.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/steps/conversion.py @@ -5,13 +5,19 @@ from .base import Step, register_step +def _is_true(value: str) -> bool: + return value.lower() == 'true' + + class WeasyPrintStep(Step): NAME = 'weasyprint' INPUT_FORMAT = FileFormats.HTML OUTPUT_FORMAT = FileFormats.PDF def __init__(self, template, options: dict): + # pylint: disable-next=import-outside-toplevel import weasyprint + super().__init__(template, options) # PDF options self.wp_options = weasyprint.DEFAULT_OPTIONS @@ -21,13 +27,13 @@ def __init__(self, template, options: dict): def wp_update_options(self, options: dict): optimize_size = tuple(options.get('render.optimize_size', 'fonts').split(',')) self.wp_options.update({ - 'pdf_identifier': options.get('pdf.identifier', 'false').lower() == 'true', + 'pdf_identifier': _is_true(options.get('pdf.identifier', 'false')), 'pdf_variant': options.get('pdf.variant', None), 'pdf_version': options.get('pdf.version', None), - 'pdf_forms': options.get('render.forms', 'false').lower() == 'true', - 'uncompressed_pdf': options.get('pdf.uncompressed', 'false').lower() == 'true', - 'custom_metadata': options.get('pdf.custom_metadata', 'false').lower() == 'true', - 'presentational_hints': options.get('render.presentational_hints', 'false').lower() == 'true', + 'pdf_forms': _is_true(options.get('render.forms', 'false')), + 'uncompressed_pdf': _is_true(options.get('pdf.uncompressed', 'false')), + 'custom_metadata': _is_true(options.get('pdf.custom_metadata', 'false')), + 'presentational_hints': _is_true(options.get('render.presentational_hints', 'false')), 'optimize_images': 'images' in optimize_size, 'jpeg_quality': int(options.get('render.jpeg_quality', '95')), 'dpi': int(options.get('render.dpi', '96')), @@ -37,9 +43,12 @@ def execute_first(self, context: dict) -> DocumentFile: return self.raise_exc(f'Step "{self.NAME}" cannot be first') def execute_follow(self, document: DocumentFile, context: dict) -> DocumentFile: + # pylint: disable-next=import-outside-toplevel import weasyprint + if document.file_format != FileFormats.HTML: - self.raise_exc(f'WeasyPrint does not support {document.file_format.name} format as input') + self.raise_exc(f'WeasyPrint does not support {document.file_format.name}' + f' format as input') file_uri = self.template.template_dir / '_file.html' wp_html = weasyprint.HTML( string=document.content.decode(DEFAULT_ENCODING), @@ -113,13 +122,15 @@ def __init__(self, template, options: dict): self.raise_exc(f'Unknown output format "{self.output_format.name}"') self.pandoc = Pandoc( config=Context.get().app.cfg, - filter_names=self._extract_filter_names(options.get(self.OPTION_FILTERS, '')), - template_name=options.get(self.OPTION_TEMPLATE, None) + filter_names=self._extract_filter_names( + filters=options.get(self.OPTION_FILTERS, ''), + ), + template_name=options.get(self.OPTION_TEMPLATE, None), ) @staticmethod def _extract_filter_names(filters: str) -> list[str]: - names = list() + names: list[str] = [] for name in filters.split(','): name = name.strip() if name: @@ -182,7 +193,10 @@ def execute_follow(self, document: DocumentFile, context: dict) -> DocumentFile: f'as input for rdflib-convert ' f'(expecting {self.input_format.name})') data = self.rdflib_convert( - self.input_format, self.output_format, document.content, self.options + source_format=self.input_format, + target_format=self.output_format, + data=document.content, + metadata=self.options, ) return DocumentFile( file_format=self.output_format, diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/steps/excel.py b/packages/dsw-document-worker/dsw/document_worker/templates/steps/excel.py index 1f9032af..13772d47 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/steps/excel.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/steps/excel.py @@ -1,23 +1,22 @@ import base64 import datetime -import dateutil.parser import io import json import pathlib +import typing +import dateutil.parser import xlsxwriter from xlsxwriter.chart import Chart from xlsxwriter.format import Format from xlsxwriter.worksheet import Worksheet -from typing import Any - from ...documents import DocumentFile, FileFormats from .base import Step, register_step, TMP_DIR, FormatStepException -_EMPTY_DICT = {} # type: dict[str, Any] -_EMPTY_LIST = [] # type: list[Any] +_EMPTY_DICT: dict[str, typing.Any] = {} +_EMPTY_LIST: list[typing.Any] = [] def _b64img2io(b64bytes: str) -> io.BytesIO: @@ -76,7 +75,7 @@ def _cell_writer_formula(worksheet: Worksheet, pos_args, item, cell_format): def _cell_writer_blank(worksheet: Worksheet, pos_args, item, cell_format): worksheet.write_blank( *pos_args, - blank=None, + blank=item, cell_format=cell_format, ) @@ -115,36 +114,37 @@ class WorkbookBuilder: def __init__(self, workbook: xlsxwriter.Workbook): self.workbook = workbook - self.sheets = list() # type: list[Worksheet] - self.formats = dict() # type: dict[str, Format] - self.charts = dict() # type: dict[str, Chart] + self.sheets: list[Worksheet] = [] + self.formats: dict[str, Format] = {} + self.charts: dict[str, Chart] = {} + self.byte_streams: list[io.BytesIO] = [] def _add_workbook_options(self, data: dict): # customized options not in regular 'constructor options' options = data.get('options', _EMPTY_DICT) - if 'active_sheet' in options.keys(): + if 'active_sheet' in options: sheet_index = options['active_sheet'] if 0 <= sheet_index < len(self.sheets): self.sheets[sheet_index].activate() - if 'vba_name' in options.keys(): + if 'vba_name' in options: self.workbook.set_vba_name(options['vba_name']) - if 'size' in options.keys(): + if 'size' in options: self.workbook.set_size( width=options['size'].get('width'), height=options['size'].get('height'), ) - if 'tab_ratio' in options.keys(): + if 'tab_ratio' in options: self.workbook.set_tab_ratio(options['tab_ratio']) def _add_workbook_properties(self, data: dict): props = data.get('properties', _EMPTY_DICT) - if 'document' in props.keys(): + if 'document' in props: if 'created' in props['document']: props['document']['created'] = dateutil.parser.parse( timestr=props['document']['created'], ) self.workbook.set_properties(props['document']) - if 'custom' in props.keys(): + if 'custom' in props: for prop in props['custom']: name = prop.get('name', 'unnamed') value = prop.get('value', '') @@ -181,51 +181,51 @@ def _add_chart(self, data: dict): def _add_chart_axis(self, chart: Chart, data: dict): axis = data.get('axis', _EMPTY_DICT) - if 'x' in axis.keys(): + if 'x' in axis: chart.set_x_axis(axis['x']) - if 'x2' in axis.keys(): + if 'x2' in axis: chart.set_x2_axis(axis['x2']) - if 'y' in axis.keys(): + if 'y' in axis: chart.set_y_axis(axis['y']) - if 'y2' in axis.keys(): + if 'y2' in axis: chart.set_y2_axis(axis['y2']) def _add_chart_basic(self, chart: Chart, data: dict): - if 'size' in data.keys(): + if 'size' in data: chart.set_size(data['size']) - if 'title' in data.keys(): + if 'title' in data: chart.set_title(data['title']) - if 'legend' in data.keys(): + if 'legend' in data: chart.set_legend(data['legend']) - if 'chartarea' in data.keys(): + if 'chartarea' in data: chart.set_chartarea(data['chartarea']) - if 'plotarea' in data.keys(): + if 'plotarea' in data: chart.set_plotarea(data['plotarea']) - if 'style' in data.keys(): + if 'style' in data: chart.set_style(data['style']) - if 'table' in data.keys(): + if 'table' in data: chart.set_table(data['table']) def _add_chart_advanced(self, chart: Chart, data: dict): - if 'combine' in data.keys(): + if 'combine' in data: other = data['combine'] - if other in self.charts.keys(): + if other in self.charts: chart.combine(self.charts[other]) - if 'up_down_bars' in data.keys(): + if 'up_down_bars' in data: chart.set_up_down_bars(data['up_down_bars']) - if 'drop_lines' in data.keys(): + if 'drop_lines' in data: chart.set_drop_lines(data['drop_lines']) - if 'high_low_lines' in data.keys(): + if 'high_low_lines' in data: chart.set_high_low_lines(data['high_low_lines']) - if 'show_blanks_as' in data.keys(): + if 'show_blanks_as' in data: chart.show_blanks_as(data['show_blanks_as']) - if 'show_hidden_data' in data.keys(): + if 'show_hidden_data' in data: chart.show_hidden_data() def _add_chartsheet(self, data: dict): name = data.get('name', None) chart_name = data.get('chart', '') - if chart_name not in self.charts.keys(): + if chart_name not in self.charts: return # ignore if chart is missing sheet = self.workbook.add_chartsheet(name) sheet.set_chart(self.charts[chart_name]) @@ -258,7 +258,8 @@ def _add_data_to_worksheet(self, worksheet: Worksheet, data: dict): item_type = item.get('type', None) if item_type is None: continue - elif item_type == 'cell': + + if item_type == 'cell': self._add_data_cell(worksheet, item) elif item_type == 'row': self._add_data_row(worksheet, item) @@ -274,7 +275,7 @@ def _add_data_to_worksheet(self, worksheet: Worksheet, data: dict): def _add_data_cell(self, worksheet: Worksheet, item: dict): subtype = item.get('subtype', '') cell_format = self.formats.get(item.get('format', ''), None) - if 'cell' in item.keys(): + if 'cell' in item: pos_args = [item['cell']] else: pos_args = [ @@ -287,7 +288,7 @@ def _add_data_cell(self, worksheet: Worksheet, item: dict): item.pop('cell', None) item.pop('row', None) item.pop('col', None) - if subtype in _CELL_WRITERS.keys(): + if subtype in _CELL_WRITERS: _CELL_WRITERS[subtype](worksheet, pos_args, item, cell_format) elif subtype == 'rich_string': parts = [] @@ -307,7 +308,7 @@ def _add_data_cell(self, worksheet: Worksheet, item: dict): def _add_data_row(self, worksheet: Worksheet, item: dict): cell_format = self.formats.get(item.get('format', ''), None) - if 'cell' in item.keys(): + if 'cell' in item: worksheet.write_row( item['cell'], data=item.get('data', []), @@ -323,7 +324,7 @@ def _add_data_row(self, worksheet: Worksheet, item: dict): def _add_data_column(self, worksheet: Worksheet, item: dict): cell_format = self.formats.get(item.get('format', ''), None) - if 'cell' in item.keys(): + if 'cell' in item: worksheet.write_column( item['cell'], data=item.get('data', []), @@ -355,7 +356,7 @@ def _add_data_chart(self, worksheet: Worksheet, item: dict): chart = self.charts.get(item.get('chart', ''), None) if chart is None: return - if 'cell' in item.keys(): + if 'cell' in item: worksheet.insert_chart( item['cell'], chart=chart, @@ -371,7 +372,7 @@ def _add_data_chart(self, worksheet: Worksheet, item: dict): @staticmethod def _add_data_comment(worksheet: Worksheet, item: dict): - if 'cell' in item.keys(): + if 'cell' in item: worksheet.write_comment( item['cell'], comment=item.get('comment', ''), @@ -387,7 +388,7 @@ def _add_data_comment(worksheet: Worksheet, item: dict): @staticmethod def _add_data_textbox(worksheet: Worksheet, item: dict): - if 'cell' in item.keys(): + if 'cell' in item: worksheet.insert_textbox( item['cell'], text=item.get('text', ''), @@ -403,7 +404,7 @@ def _add_data_textbox(worksheet: Worksheet, item: dict): @staticmethod def _add_data_button(worksheet: Worksheet, item: dict): - if 'cell' in item.keys(): + if 'cell' in item: worksheet.insert_button( item['cell'], options=item.get('options', None), @@ -415,15 +416,14 @@ def _add_data_button(worksheet: Worksheet, item: dict): options=item.get('options', None), ) - @staticmethod - def _add_data_image(worksheet: Worksheet, item: dict): - bytes_io = io.BytesIO() - if 'b64bytes' in item.keys(): - if 'options' not in item.keys(): - item['options'] = dict() + def _add_data_image(self, worksheet: Worksheet, item: dict): + if 'b64bytes' in item: + if 'options' not in item: + item['options'] = {} bytes_io = _b64img2io(item['b64bytes']) item['options']['image_data'] = bytes_io - if 'cell' in item.keys(): + self.byte_streams.append(bytes_io) + if 'cell' in item: worksheet.insert_image( item['cell'], filename=item.get('filename', ''), @@ -436,14 +436,13 @@ def _add_data_image(worksheet: Worksheet, item: dict): filename=item.get('filename', ''), options=item.get('options', None), ) - # TODO: check closed io at the end (cannot close before closing Excel) def _add_data_array_formula(self, worksheet: Worksheet, item: dict): method = worksheet.write_array_formula if item.get('dynamic', False): method = worksheet.write_dynamic_array_formula cell_format = self.formats.get(item.get('format', ''), None) - if 'range' in item.keys(): + if 'range' in item: method( item['range'], formula=item.get('formula', ''), @@ -463,23 +462,23 @@ def _add_data_array_formula(self, worksheet: Worksheet, item: dict): @classmethod def _setup_worksheet_print(cls, worksheet: Worksheet, data: dict): - if 'orientation' in data.keys(): + if 'orientation' in data: if data['orientation'] == 'landscape': worksheet.set_landscape() elif data['orientation'] == 'portrait': worksheet.set_portrait() - if 'paper' in data.keys(): + if 'paper' in data: worksheet.set_paper(data['paper']) - if 'margins' in data.keys(): + if 'margins' in data: worksheet.set_margins(**data['margins']) - if 'header' in data.keys(): + if 'header' in data: options = data['header'].get('options', None) cls._fix_footer_header_images(options) worksheet.set_header( header=data['header'].get('content', ''), options=options, ) - if 'footer' in data.keys(): + if 'footer' in data: options = data['footer'].get('options', None) cls._fix_footer_header_images(options) worksheet.set_footer( @@ -496,12 +495,12 @@ def _fix_footer_header_images(options): if not isinstance(options, dict): return for key in ('image_data_left', 'image_data_center', 'image_data_right'): - if key in options.keys(): + if key in options: options[key] = _b64img2io(options[key]) @classmethod - def _setup_worksheet_common(cls, worksheet: Worksheet, data: dict): - data = data.get('options', None) + def _setup_worksheet_common(cls, worksheet: Worksheet, container: dict): + data: dict | None = container.get('options', None) if data is None: return if data.get('select', False): @@ -510,21 +509,21 @@ def _setup_worksheet_common(cls, worksheet: Worksheet, data: dict): worksheet.hide() if data.get('first_sheet', False): worksheet.set_first_sheet() - if 'protect' in data.keys(): + if 'protect' in data: worksheet.protect( password=data['protect'].get('password', ''), options=data['protect'].get('options', None), ) - if 'zoom' in data.keys(): + if 'zoom' in data: worksheet.set_zoom(data['zoom']) - if 'tab_color' in data.keys(): + if 'tab_color' in data: worksheet.set_tab_color(data['tab_color']) if data.get('page_view', False): worksheet.set_page_view() cls._setup_worksheet_print(worksheet, data) - def _setup_worksheet_data(self, worksheet: Worksheet, data: dict): - data = data.get('options', None) + def _setup_worksheet_data(self, worksheet: Worksheet, container: dict): + data: dict | None = container.get('options', None) if data is None: return self._setup_worksheet_basic(worksheet, data) @@ -545,7 +544,7 @@ def _setup_worksheet_data(self, worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_basic(worksheet: Worksheet, data: dict): - if 'comments_author' in data.keys(): + if 'comments_author' in data: worksheet.set_comments_author(data['comments_author']) if data.get('hide_zero', False): worksheet.hide_zero() @@ -553,18 +552,18 @@ def _setup_worksheet_basic(worksheet: Worksheet, data: dict): worksheet.hide_row_col_headers() if data.get('right_to_left', False): worksheet.right_to_left() - if 'hide_gridlines' in data.keys(): + if 'hide_gridlines' in data: worksheet.hide_gridlines(data['hide_gridlines']) - if 'ignore_errors' in data.keys(): + if 'ignore_errors' in data: worksheet.ignore_errors(data['ignore_errors']) - if 'vba_name' in data.keys(): + if 'vba_name' in data: worksheet.set_vba_name(data['vba_name']) @staticmethod def _setup_worksheet_printing(worksheet: Worksheet, data: dict): if data.get('print_row_col_headers', False): worksheet.print_row_col_headers() - if 'print_area' in data.keys(): + if 'print_area' in data: area = data['print_area'] if isinstance(area, dict): worksheet.print_area(**area) @@ -572,27 +571,27 @@ def _setup_worksheet_printing(worksheet: Worksheet, data: dict): worksheet.print_area(area) if data.get('print_across', False): worksheet.print_across() - if 'fit_to_pages' in data.keys(): + if 'fit_to_pages' in data: worksheet.fit_to_pages( width=data['fit_to_pages'].get('width', 1), height=data['fit_to_pages'].get('height', 1), ) - if 'start_page' in data.keys(): + if 'start_page' in data: worksheet.set_start_page(data['start_page']) - if 'print_scale' in data.keys(): + if 'print_scale' in data: worksheet.set_print_scale(data['print_scale']) if data.get('print_black_and_white', False): worksheet.print_black_and_white() @staticmethod def _setup_worksheet_special_ranges(worksheet: Worksheet, data: dict): - if 'unprotect_ranges' in data.keys(): + if 'unprotect_ranges' in data: for r in data['unprotect_ranges']: worksheet.unprotect_range( cell_range=r.get('range', 'A1'), range_name=r.get('name', None), ) - if 'top_left_cell' in data.keys(): + if 'top_left_cell' in data: if isinstance(data['top_left_cell'], str): worksheet.set_top_left_cell(data['top_left_cell']) else: @@ -600,7 +599,7 @@ def _setup_worksheet_special_ranges(worksheet: Worksheet, data: dict): row=data['top_left_cell'].get('row', 0), col=data['top_left_cell'].get('col', 0), ) - if 'selection' in data.keys(): + if 'selection' in data: if isinstance(data['selection'], str): worksheet.set_selection(data['selection']) else: @@ -615,17 +614,17 @@ def _setup_worksheet_special_ranges(worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_repeats(worksheet: Worksheet, data: dict): - if 'repeat_rows' in data.keys(): + if 'repeat_rows' in data: worksheet.repeat_rows( first_row=data['repeat_rows'].get('first_row', 0), last_row=data['repeat_rows'].get('last_row', None), ) - if 'repeat_columns' in data.keys(): + if 'repeat_columns' in data: worksheet.repeat_columns( first_row=data['repeat_columns'].get('first_col', 0), last_row=data['repeat_columns'].get('last_col', None), ) - if 'default_row' in data.keys(): + if 'default_row' in data: worksheet.set_default_row( height=data['default_row'].get('height', 15), hide_unused_rows=data['default_row'].get('hide_unused_rows', False), @@ -633,11 +632,11 @@ def _setup_worksheet_repeats(worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_paging(worksheet: Worksheet, data: dict): - if 'h_pagebreaks' in data.keys(): + if 'h_pagebreaks' in data: worksheet.set_h_pagebreaks(data['h_pagebreaks']) - if 'v_pagebreaks' in data.keys(): + if 'v_pagebreaks' in data: worksheet.set_v_pagebreaks(data['v_pagebreaks']) - if 'outline_settings' in data.keys(): + if 'outline_settings' in data: worksheet.outline_settings( visible=data['outline_settings'].get('visible', True), symbols_below=data['outline_settings'].get('symbols_below', True), @@ -647,14 +646,14 @@ def _setup_worksheet_paging(worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_panes(worksheet: Worksheet, data: dict): - if 'split_panes' in data.keys(): + if 'split_panes' in data: worksheet.split_panes( x=data['split_panes'].get('x', 0), y=data['split_panes'].get('y', 0), top_row=data['split_panes'].get('top_row', 0), left_col=data['split_panes'].get('left_col', 0), ) - if 'freeze_panes' in data.keys(): + if 'freeze_panes' in data: worksheet.split_panes( x=data['split_panes'].get('x', 0), y=data['split_panes'].get('y', 0), @@ -664,19 +663,19 @@ def _setup_worksheet_panes(worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_filters(worksheet: Worksheet, data: dict): - if 'filter_column_lists' in data.keys(): + if 'filter_column_lists' in data: for f in data['filter_column_list']: worksheet.filter_column_list( col=f.get('col', 0), filters=f.get('filters', _EMPTY_LIST), ) - if 'filter_columns' in data.keys(): + if 'filter_columns' in data: for f in data['filter_column_list']: worksheet.filter_column( col=f.get('col', 0), criteria=f.get('criteria', _EMPTY_LIST), ) - if 'autofilter' in data.keys(): + if 'autofilter' in data: if isinstance(data['autofilter'], str): worksheet.set_selection(data['autofilter']) else: @@ -688,10 +687,10 @@ def _setup_worksheet_filters(worksheet: Worksheet, data: dict): ) def _setup_worksheet_merge_ranges(self, worksheet: Worksheet, data: dict): - if 'merge_ranges' in data.keys(): + if 'merge_ranges' in data: for item in data['merge_ranges']: cell_format = self.formats.get(item.get('format', ''), None) - if 'range' in item.keys(): + if 'range' in item: worksheet.merge_range( item['range'], data=item.get('data', ''), @@ -709,9 +708,9 @@ def _setup_worksheet_merge_ranges(self, worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_data_validations(worksheet: Worksheet, data: dict): - if 'data_validations' in data.keys(): + if 'data_validations' in data: for item in data['data_validations']: - if 'range' in item.keys(): + if 'range' in item: worksheet.data_validation( item['range'], options=item.get('options', ''), @@ -726,11 +725,11 @@ def _setup_worksheet_data_validations(worksheet: Worksheet, data: dict): ) def _setup_worksheet_conditional_formats(self, worksheet: Worksheet, data: dict): - if 'conditional_formats' in data.keys(): + if 'conditional_formats' in data: for item in data['conditional_formats']: - if 'format' in item.keys(): + if 'format' in item: item['format'] = self.formats.get(item['format'], None) - if 'range' in item.keys(): + if 'range' in item: worksheet.conditional_format( item['range'], options=item.get('options', ''), @@ -745,13 +744,13 @@ def _setup_worksheet_conditional_formats(self, worksheet: Worksheet, data: dict) ) def _setup_worksheet_tables(self, worksheet: Worksheet, data: dict): - if 'tables' in data.keys(): + if 'tables' in data: for item in data['tables']: self._replace_nested_formats( data=item, keys=frozenset(['format', 'header_format']), ) - if 'range' in item.keys(): + if 'range' in item: worksheet.add_table( item['range'], options=item.get('options', ''), @@ -767,9 +766,9 @@ def _setup_worksheet_tables(self, worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_sparklines(worksheet: Worksheet, data: dict): - if 'sparklines' in data.keys(): + if 'sparklines' in data: for item in data['sparklines']: - if 'cell' in item.keys(): + if 'cell' in item: worksheet.add_sparkline( item['cell'], options=item.get('options', ''), @@ -782,9 +781,9 @@ def _setup_worksheet_sparklines(worksheet: Worksheet, data: dict): ) def _setup_worksheet_row_sizing(self, worksheet: Worksheet, data: dict): - if 'row_pixels' in data.keys(): + if 'row_pixels' in data: for item in data['row_pixels']: - if 'format' in item.keys(): + if 'format' in item: item['format'] = self.formats.get(item['format'], None) worksheet.set_row_pixels( row=item.get('row', 0), @@ -792,9 +791,9 @@ def _setup_worksheet_row_sizing(self, worksheet: Worksheet, data: dict): cell_format=item.get('format', None), options=item.get('options', None), ) - if 'rows' in data.keys(): + if 'rows' in data: for item in data['rows']: - if 'format' in item.keys(): + if 'format' in item: item['format'] = self.formats.get(item['format'], None) worksheet.set_row( row=item.get('row', 0), @@ -804,9 +803,9 @@ def _setup_worksheet_row_sizing(self, worksheet: Worksheet, data: dict): ) def _setup_worksheet_col_sizing(self, worksheet: Worksheet, data: dict): - if 'column_pixels' in data.keys(): + if 'column_pixels' in data: for item in data['column_pixels']: - if 'format' in item.keys(): + if 'format' in item: item['format'] = self.formats.get(item['format'], None) first_col = item.get('first_col', item.get('col', 0)) worksheet.set_column_pixels( @@ -816,9 +815,9 @@ def _setup_worksheet_col_sizing(self, worksheet: Worksheet, data: dict): cell_format=item.get('format', None), options=item.get('options', None), ) - if 'columns' in data.keys(): + if 'columns' in data: for item in data['columns']: - if 'format' in item.keys(): + if 'format' in item: item['format'] = self.formats.get(item['format'], None) first_col = item.get('first_col', item.get('col', 0)) worksheet.set_column( @@ -831,13 +830,13 @@ def _setup_worksheet_col_sizing(self, worksheet: Worksheet, data: dict): @staticmethod def _setup_worksheet_background(worksheet: Worksheet, data: dict): - if 'background' in data.keys(): - if 'filename' in data['background'].keys(): + if 'background' in data: + if 'filename' in data['background']: worksheet.set_background( filename=data['background']['filename'], is_byte_stream=False, ) - elif 'b64bytes' in data['background'].keys(): + elif 'b64bytes' in data['background']: bytes_io = _b64img2io(data['background']['b64bytes']) worksheet.set_background( bytes_io, @@ -880,6 +879,10 @@ def build(self, data: dict): else: self._add_worksheet(sheet_data) + def cleanup(self): + for stream in self.byte_streams: + stream.close() + @staticmethod def build_to_bytes(tmp_file: pathlib.Path, input_data: dict) -> bytes: options = input_data.get('options', None) @@ -889,6 +892,7 @@ def build_to_bytes(tmp_file: pathlib.Path, input_data: dict) -> bytes: builder.build(data=input_data) data = tmp_file.read_bytes() tmp_file.unlink() + builder.cleanup() return data @staticmethod @@ -938,7 +942,7 @@ def execute_follow(self, document: DocumentFile, context: dict) -> DocumentFile: ) except Exception as e: raise FormatStepException(f'Failed to construct Excel document ' - f'due to: {str(e)}') + f'due to: {str(e)}') from e register_step(ExcelStep.NAME, ExcelStep) diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/steps/template.py b/packages/dsw-document-worker/dsw/document_worker/templates/steps/template.py index 8ac5e38c..b012d54d 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/steps/template.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/steps/template.py @@ -1,15 +1,18 @@ -import typing as t +import json +import typing +import gettext import jinja2 import jinja2.exceptions import jinja2.sandbox -import json - -from typing import Any +import rdflib from ...consts import DEFAULT_ENCODING from ...context import Context from ...documents import DocumentFile, FileFormat, FileFormats +from ...model.http import RequestsWrapper +from ..filters import filters +from ..tests import tests from .base import Step, register_step @@ -30,7 +33,7 @@ def execute_follow(self, document: DocumentFile, context: dict) -> DocumentFile: class JinjaEnvironment(jinja2.sandbox.SandboxedEnvironment): - def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool: + def is_safe_attribute(self, obj: typing.Any, attr: str, value: typing.Any) -> bool: if attr in ['os', 'subprocess', 'eval', 'exec', 'popen', 'system']: return False if attr == '__setitem__' and isinstance(obj, dict): @@ -38,32 +41,14 @@ def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool: return super().is_safe_attribute(obj, attr, value) -class Jinja2Step(Step): - NAME = 'jinja' - DEFAULT_FORMAT = FileFormats.HTML - - OPTION_ROOT_FILE = 'template' - OPTION_CONTENT_TYPE = 'content-type' - OPTION_EXTENSION = 'extension' +class JinjaPoweredStep(Step): OPTION_JINJA_EXT = 'jinja-ext' OPTION_I18N_DIR = 'i18n-dir' OPTION_I18N_DOMAIN = 'i18n-domain' OPTION_I18N_LANG = 'i18n-lang' - def _jinja_exception_msg(self, e: jinja2.exceptions.TemplateSyntaxError): - lines = [ - 'Failed loading Jinja2 template due to syntax error:', - f'- {e.message}', - f'- Filename: {e.name}', - f'- Line number: {e.lineno}', - ] - return '\n'.join(lines) - - def __init__(self, template, options: dict): + def __init__(self, template, options): super().__init__(template, options) - self.root_file = self.options[self.OPTION_ROOT_FILE] - self.content_type = self.options.get(self.OPTION_CONTENT_TYPE, self.DEFAULT_FORMAT.content_type) - self.extension = self.options.get(self.OPTION_EXTENSION, self.DEFAULT_FORMAT.file_extension) self.jinja_ext = frozenset( map(lambda x: x.strip(), self.options.get(self.OPTION_JINJA_EXT, '').split(',')) ) @@ -71,7 +56,6 @@ def __init__(self, template, options: dict): self.i18n_domain = self.options.get(self.OPTION_I18N_DOMAIN, 'default') self.i18n_lang = self.options.get(self.OPTION_I18N_LANG, None) - self.output_format = FileFormat(self.extension, self.content_type, self.extension) try: self.j2_env = JinjaEnvironment( loader=jinja2.FileSystemLoader(searchpath=template.template_dir), @@ -86,21 +70,29 @@ def __init__(self, template, options: dict): self.j2_env.add_extension('jinja2.ext.debug') self._apply_policies(options) self._add_j2_enhancements() - self.j2_root_template = self.j2_env.get_template(self.root_file) except jinja2.exceptions.TemplateSyntaxError as e: self.raise_exc(self._jinja_exception_msg(e)) except Exception as e: self.raise_exc(f'Failed loading Jinja2 template: {e}') + def _jinja_exception_msg(self, e: jinja2.exceptions.TemplateSyntaxError): + lines = [ + 'Failed loading Jinja2 template due to syntax error:', + f'- {e.message}', + f'- Filename: {e.name}', + f'- Line number: {e.lineno}', + ] + return '\n'.join(lines) + def _apply_policies(self, options: dict): # https://jinja.palletsprojects.com/en/3.0.x/api/#policies - policies = { + policies: dict[str, typing.Any] = { 'policy.urlize.target': '_blank', 'json.dumps_kwargs': { 'allow_nan': False, 'ensure_ascii': False, }, - } # type: dict[str,Any] + } if 'policy.truncate.leeway' in options: policies['truncate.leeway'] = options['policy.truncate.leeway'] if 'policy.urlize.rel' in options: @@ -127,23 +119,19 @@ def _add_j2_i18n(self, template): # https://jinja.palletsprojects.com/en/3.1.x/extensions/#i18n-extension self.j2_env.add_extension('jinja2.ext.i18n') if self.i18n_dir is not None and self.i18n_lang is not None: - import gettext locale_path = template.template_dir / self.i18n_dir translations = gettext.translation( domain=self.i18n_domain, localedir=locale_path, languages=map(lambda x: x.strip(), self.i18n_lang.split(',')), ) + # pylint: disable-next=no-member self.j2_env.install_gettext_translations(translations, newstyle=True) # type: ignore else: + # pylint: disable-next=no-member self.j2_env.install_null_translations(newstyle=True) # type: ignore def _add_j2_enhancements(self): - from ..filters import filters - from ..tests import tests - from ...model.http import RequestsWrapper - import rdflib - import json self.j2_env.filters.update(filters) self.j2_env.tests.update(tests) template_cfg = Context.get().app.cfg.templates.get_config( @@ -151,13 +139,38 @@ def _add_j2_enhancements(self): ) self.j2_env.globals.update({'rdflib': rdflib, 'json': json}) if template_cfg is not None: - global_vars = {'secrets': template_cfg.secrets} # type: dict[str,Any] + global_vars: dict[str, typing.Any] = {'secrets': template_cfg.secrets} if template_cfg.requests.enabled: global_vars['requests'] = RequestsWrapper( template_cfg=template_cfg, ) self.j2_env.globals.update(global_vars) + +class Jinja2Step(JinjaPoweredStep): + NAME = 'jinja' + DEFAULT_FORMAT = FileFormats.HTML + + OPTION_ROOT_FILE = 'template' + OPTION_CONTENT_TYPE = 'content-type' + OPTION_EXTENSION = 'extension' + + def __init__(self, template, options: dict): + super().__init__(template, options) + self.root_file = self.options[self.OPTION_ROOT_FILE] + self.content_type = self.options.get(self.OPTION_CONTENT_TYPE, + self.DEFAULT_FORMAT.content_type) + self.extension = self.options.get(self.OPTION_EXTENSION, + self.DEFAULT_FORMAT.file_extension) + + self.output_format = FileFormat(self.extension, self.content_type, self.extension) + try: + self.j2_root_template = self.j2_env.get_template(self.root_file) + except jinja2.exceptions.TemplateSyntaxError as e: + self.raise_exc(self._jinja_exception_msg(e)) + except Exception as e: + self.raise_exc(f'Failed loading Jinja2 template: {e}') + def _execute(self, **jinja_args): def asset_fetcher(file_name): return self.template.fetch_asset(file_name) diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/steps/word.py b/packages/dsw-document-worker/dsw/document_worker/templates/steps/word.py index 657584a8..57efa4fc 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/steps/word.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/steps/word.py @@ -1,67 +1,28 @@ import pathlib -import jinja2 import shutil import zipfile -from typing import Any, Optional +import jinja2 from ...consts import DEFAULT_ENCODING -from ...context import Context from ...documents import DocumentFile, FileFormats -from .base import Step, register_step, TMP_DIR +from .base import register_step, TMP_DIR +from .template import JinjaPoweredStep -class EnrichDocxStep(Step): +class EnrichDocxStep(JinjaPoweredStep): NAME = 'enrich-docx' INPUT_FORMAT = FileFormats.DOCX OUTPUT_FORMAT = FileFormats.DOCX - def _jinja_exception_msg(self, e: jinja2.exceptions.TemplateSyntaxError): - lines = [ - 'Failed loading Jinja2 template due to syntax error:', - f'- {e.message}', - f'- Filename: {e.name}', - f'- Line number: {e.lineno}', - ] - return '\n'.join(lines) - def __init__(self, template, options: dict): super().__init__(template, options) self.rewrites = {k[8:]: v for k, v in options.items() if k.startswith('rewrite:')} - # TODO: shared part with Jinja2Step - try: - self.j2_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(searchpath=template.template_dir), - extensions=['jinja2.ext.do'], - ) - self._add_j2_enhancements() - except jinja2.exceptions.TemplateSyntaxError as e: - self.raise_exc(self._jinja_exception_msg(e)) - except Exception as e: - self.raise_exc(f'Failed loading Jinja2 template: {e}') - - def _add_j2_enhancements(self): - # TODO: shared part with Jinja2Step - from ..filters import filters - from ..tests import tests - from ...model.http import RequestsWrapper - self.j2_env.filters.update(filters) - self.j2_env.tests.update(tests) - template_cfg = Context.get().app.cfg.templates.get_config( - self.template.template_id, - ) - if template_cfg is not None: - global_vars = {'secrets': template_cfg.secrets} # type: dict[str, Any] - if template_cfg.requests.enabled: - global_vars['requests'] = RequestsWrapper( - template_cfg=template_cfg, - ) - self.j2_env.globals.update(global_vars) def _render_rewrite(self, rewrite_template: str, context: dict, - existing_content: Optional[str]) -> str: + existing_content: str | None) -> str: try: j2_template = self.j2_env.get_template(rewrite_template) return j2_template.render( @@ -76,16 +37,17 @@ def _render_rewrite(self, rewrite_template: str, context: dict, def _static_rewrite(self, rewrite_file: str) -> str: try: - path = self.template.template_dir / rewrite_file # type: pathlib.Path + path: pathlib.Path = self.template.template_dir / rewrite_file return path.read_text(encoding=DEFAULT_ENCODING) except Exception as e: self.raise_exc(f'Failed loading Jinja2 template: {e}') return '' - def _get_rewrite(self, rewrite: str, context: dict, existing_content: Optional[str]) -> str: + def _get_rewrite(self, rewrite: str, context: dict, + existing_content: str | None) -> str: if rewrite.startswith('static:'): return self._static_rewrite(rewrite[7:]) - elif rewrite.startswith('render:'): + if rewrite.startswith('render:'): return self._render_rewrite(rewrite[7:], context, existing_content) return '' diff --git a/packages/dsw-document-worker/dsw/document_worker/templates/templates.py b/packages/dsw-document-worker/dsw/document_worker/templates/templates.py index 32054333..6e73a714 100644 --- a/packages/dsw-document-worker/dsw/document_worker/templates/templates.py +++ b/packages/dsw-document-worker/dsw/document_worker/templates/templates.py @@ -1,11 +1,10 @@ import base64 +import dataclasses import datetime import logging import pathlib import shutil -from typing import Optional - from dsw.database.database import DBDocumentTemplate, \ DBDocumentTemplateFile, DBDocumentTemplateAsset @@ -47,12 +46,11 @@ def src_value(self): return f'data:{self.content_type};base64,{self.data_base64}' +@dataclasses.dataclass class TemplateComposite: - - def __init__(self, db_template, db_files, db_assets): - self.template = db_template # type: DBDocumentTemplate - self.files = db_files # type: dict[str, DBDocumentTemplateFile] - self.assets = db_assets # type: dict[str, DBDocumentTemplateAsset] + template: DBDocumentTemplate + files: dict[str, DBDocumentTemplateFile] + assets: dict[str, DBDocumentTemplateAsset] class Template: @@ -64,7 +62,7 @@ def __init__(self, tenant_uuid: str, template_dir: pathlib.Path, self.last_used = datetime.datetime.now(tz=datetime.UTC) self.db_template = db_template self.template_id = self.db_template.template.id - self.formats = dict() # type: dict[str, Format] + self.formats: dict[str, Format] = {} self.asset_prefix = f'templates/{self.db_template.template.id}' if Context.get().app.cfg.cloud.multi_tenant: self.asset_prefix = f'{self.tenant_uuid}/{self.asset_prefix}' @@ -72,8 +70,8 @@ def __init__(self, tenant_uuid: str, template_dir: pathlib.Path, def raise_exc(self, message: str): raise TemplateException(self.template_id, message) - def fetch_asset(self, file_name: str) -> Optional[Asset]: - LOG.info(f'Fetching asset "{file_name}"') + def fetch_asset(self, file_name: str) -> Asset | None: + LOG.info('Fetching asset "%s"', file_name) file_path = self.template_dir / file_name asset = None for a in self.db_template.assets.values(): @@ -81,7 +79,7 @@ def fetch_asset(self, file_name: str) -> Optional[Asset]: asset = a break if asset is None or not file_path.exists(): - LOG.error(f'Asset "{file_name}" not found') + LOG.error('Asset "%s" not found', file_name) return None return Asset( asset_uuid=asset.uuid, @@ -94,16 +92,19 @@ def asset_path(self, filename: str) -> str: return str(self.template_dir / filename) def _store_asset(self, asset: DBDocumentTemplateAsset): - LOG.debug(f'Storing asset {asset.uuid} ({asset.file_name})') + LOG.debug('Storing asset %s (%s)', asset.uuid, asset.file_name) remote_path = f'{self.asset_prefix}/{asset.uuid}' local_path = self.template_dir / asset.file_name local_path.parent.mkdir(parents=True, exist_ok=True) - result = Context.get().app.s3.download_file(remote_path, local_path) + result = Context.get().app.s3.download_file( + file_name=remote_path, + target_path=local_path, + ) if not result: - LOG.error(f'Asset "{local_path.name}" cannot be retrieved') + LOG.error('Asset "%s" cannot be retrieved', local_path.name) def _store_file(self, file: DBDocumentTemplateFile): - LOG.debug(f'Storing file {file.uuid} ({file.file_name})') + LOG.debug('Storing file %s (%s)', file.uuid, file.file_name) local_path = self.template_dir / file.file_name local_path.parent.mkdir(parents=True, exist_ok=True) local_path.write_text( @@ -112,45 +113,45 @@ def _store_file(self, file: DBDocumentTemplateFile): ) def _delete_asset(self, asset: DBDocumentTemplateAsset): - LOG.debug(f'Deleting asset {asset.uuid} ({asset.file_name})') + LOG.debug('Deleting asset %s (%s)', asset.uuid, asset.file_name) local_path = self.template_dir / asset.file_name local_path.unlink(missing_ok=True) def _delete_file(self, file: DBDocumentTemplateFile): - LOG.debug(f'Deleting file {file.uuid} ({file.file_name})') + LOG.debug('Deleting file %s (%s)', file.uuid, file.file_name) local_path = self.template_dir / file.file_name local_path.unlink(missing_ok=True) def _update_asset(self, asset: DBDocumentTemplateAsset): - LOG.debug(f'Updating asset {asset.uuid} ({asset.file_name})') + LOG.debug('Updating asset %s (%s)', asset.uuid, asset.file_name) old_asset = self.db_template.assets[asset.uuid] local_path = self.template_dir / asset.file_name if old_asset.updated_at == asset.updated_at and local_path.exists(): - LOG.debug(f'- Asset {asset.uuid} ({asset.file_name}) did not change') + LOG.debug('- Asset %s (%s) did not change', asset.uuid, asset.file_name) return self._store_asset(asset) def _update_file(self, file: DBDocumentTemplateFile): - LOG.debug(f'Updating file {file.uuid} ({file.file_name})') + LOG.debug('Updating file %s (%s)', file.uuid, file.file_name) old_file = self.db_template.files[file.uuid] local_path = self.template_dir / file.file_name if old_file.updated_at == file.updated_at and local_path.exists(): - LOG.debug(f'- File {file.uuid} ({file.file_name}) did not change') + LOG.debug('- File %s (%s) did not change', file.uuid, file.file_name) return self._store_file(file) def prepare_all_template_files(self): - LOG.info(f'Storing all files of template {self.template_id} locally') + LOG.info('Storing all files of template %s locally', self.template_id) for file in self.db_template.files.values(): self._store_file(file) def prepare_all_template_assets(self): - LOG.info(f'Storing all assets of template {self.template_id} locally') + LOG.info('Storing all assets of template %s locally', self.template_id) for asset in self.db_template.assets.values(): self._store_asset(asset) def prepare_fs(self): - LOG.info(f'Preparing directory for template {self.template_id}') + LOG.info('Preparing directory for template %s', self.template_id) if self.template_dir.exists(): shutil.rmtree(self.template_dir) self.template_dir.mkdir(parents=True) @@ -165,7 +166,7 @@ def _resolve_change(old_keys: frozenset[str], new_keys: frozenset[str]): return to_add, to_del, to_chk def update_template_files(self, db_files: dict[str, DBDocumentTemplateFile]): - LOG.info(f'Updating files of template {self.template_id}') + LOG.info('Updating files of template %s', self.template_id) to_add, to_del, to_chk = self._resolve_change( old_keys=frozenset(self.db_template.files.keys()), new_keys=frozenset(db_files.keys()), @@ -179,7 +180,7 @@ def update_template_files(self, db_files: dict[str, DBDocumentTemplateFile]): self.db_template.files = db_files def update_template_assets(self, db_assets: dict[str, DBDocumentTemplateAsset]): - LOG.info(f'Updating assets of template {self.template_id}') + LOG.info('Updating assets of template %s', self.template_id) to_add, to_del, to_chk = self._resolve_change( old_keys=frozenset(self.db_template.assets.keys()), new_keys=frozenset(db_assets.keys()), @@ -231,15 +232,15 @@ def get(cls) -> 'TemplateRegistry': return cls._instance def __init__(self): - self._templates = dict() # type: dict[str, dict[str, Template]] + self._templates: dict[str, dict[str, Template]] = {} def has_template(self, tenant_uuid: str, template_id: str) -> bool: - return tenant_uuid in self._templates.keys() and \ - template_id in self._templates[tenant_uuid].keys() + return tenant_uuid in self._templates and \ + template_id in self._templates[tenant_uuid] def _set_template(self, tenant_uuid: str, template_id: str, template: Template): - if tenant_uuid not in self._templates.keys(): - self._templates[tenant_uuid] = dict() + if tenant_uuid not in self._templates: + self._templates[tenant_uuid] = {} self._templates[tenant_uuid][template_id] = template def get_template(self, tenant_uuid: str, template_id: str) -> Template: @@ -264,19 +265,19 @@ def _refresh_template(self, tenant_uuid: str, template_id: str, def prepare_template(self, tenant_uuid: str, template_id: str) -> Template: ctx = Context.get() - query_args = dict( - template_id=template_id, - tenant_uuid=tenant_uuid, - ) + query_args = { + 'template_id': template_id, + 'tenant_uuid': tenant_uuid, + } db_template = ctx.app.db.fetch_template(**query_args) if db_template is None: raise RuntimeError(f'Template {template_id} not found in database') db_files = ctx.app.db.fetch_template_files(**query_args) db_assets = ctx.app.db.fetch_template_assets(**query_args) template_composite = TemplateComposite( - db_template=db_template, - db_files={f.uuid: f for f in db_files}, - db_assets={f.uuid: f for f in db_assets}, + template=db_template, + files={f.uuid: f for f in db_files}, + assets={f.uuid: f for f in db_assets}, ) if self.has_template(tenant_uuid, template_id): @@ -292,7 +293,6 @@ def _clear_template(self, tenant_uuid: str, template_id: str): shutil.rmtree(template.template_dir) def cleanup(self): - # TODO: configurable threshold = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(days=7) for tenant_uuid, templates in self._templates.items(): for template_id, template in templates.items(): diff --git a/packages/dsw-document-worker/dsw/document_worker/worker.py b/packages/dsw-document-worker/dsw/document_worker/worker.py index 8f302f96..e33567c8 100644 --- a/packages/dsw-document-worker/dsw/document_worker/worker.py +++ b/packages/dsw-document-worker/dsw/document_worker/worker.py @@ -1,17 +1,16 @@ import datetime -import dateutil.parser import functools import logging import pathlib -import sentry_sdk.types as sentry +import typing -from typing import Optional +import dateutil.parser +import sentry_sdk.types as sentry from dsw.command_queue import CommandWorker, CommandQueue from dsw.config.sentry import SentryReporter from dsw.database.database import Database -from dsw.database.model import DBDocument, DBTenantConfig, \ - DBTenantLimits, PersistentCommand +from dsw.database.model import DBDocument, PersistentCommand from dsw.storage import S3Storage from .build_info import BUILD_INFO @@ -37,11 +36,12 @@ def handled_step(job, *args, **kwargs): return func(job, *args, **kwargs) except Exception as e: LOG.debug('Handling exception', exc_info=True) - raise create_job_exception( + new_exception = create_job_exception( job_id=job.doc_uuid, message=message, exc=e, ) + raise new_exception from e return handled_step return decorator @@ -50,16 +50,16 @@ class Job: def __init__(self, command: PersistentCommand, document_uuid: str): self.ctx = Context.get() - self.template = None # type: Optional[Template] - self.format = None # type: Optional[Format] - self.tenant_uuid = command.tenant_uuid # type: str - self.doc_uuid = document_uuid # type: str - self.doc_context = command.body # type: dict - self.doc = None # type: Optional[DBDocument] - self.final_file = None # type: Optional[DocumentFile] - self.template_config = None # type: Optional[TemplateConfig] - self.tenant_config = self.ctx.app.db.get_tenant_config(self.tenant_uuid) # type: Optional[DBTenantConfig] - self.tenant_limits = self.ctx.app.db.fetch_tenant_limits(self.tenant_uuid) # type: Optional[DBTenantLimits] + self.template: Template | None = None + self.format: Format | None = None + self.tenant_uuid: str = command.tenant_uuid + self.doc_uuid: str = document_uuid + self.doc_context: dict = command.body + self.doc: DBDocument | None = None + self.final_file: DocumentFile | None = None + self.template_config: TemplateConfig | None = None + self.tenant_config = self.ctx.app.db.get_tenant_config(self.tenant_uuid) + self.tenant_limits = self.ctx.app.db.fetch_tenant_limits(self.tenant_uuid) @property def safe_doc(self) -> DBDocument: @@ -93,8 +93,8 @@ def get_document(self): format='?', ) if self.tenant_uuid != NULL_UUID: - LOG.info(f'Limiting to tenant with UUID: {self.tenant_uuid}') - LOG.info(f'Getting the document "{self.doc_uuid}" details from DB') + LOG.info('Limiting to tenant with UUID: %s', self.tenant_uuid) + LOG.info('Getting the document "%s" details from DB', self.doc_uuid) self.doc = self.ctx.app.db.fetch_document( document_uuid=self.doc_uuid, tenant_uuid=self.tenant_uuid, @@ -105,10 +105,10 @@ def get_document(self): message='Document record not found in database', ) self.doc.retrieved_at = datetime.datetime.now(tz=datetime.UTC) - LOG.info(f'Job "{self.doc_uuid}" details received') + LOG.info('Job "%s" details received', self.doc_uuid) # verify state state = self.doc.state - LOG.info(f'Original state of job is {state}') + LOG.info('Original state of job is %s', state) if state == DocumentState.FINISHED: raise create_job_exception( job_id=self.doc_uuid, @@ -124,7 +124,8 @@ def prepare_template(self): SentryReporter.set_tags(phase='prepare') template_id = self.safe_doc.document_template_id format_uuid = self.safe_doc.format_uuid - LOG.info(f'Document uses template {template_id} with format {format_uuid}') + LOG.info('Document uses template %s with format %s', + template_id, format_uuid) # update Sentry info SentryReporter.set_tags( template=template_id, @@ -144,7 +145,7 @@ def prepare_template(self): self.template_config = self.ctx.app.cfg.templates.get_config(template_id) def _enrich_context(self): - extras = dict() + extras: dict[str, typing.Any] = {} if self.safe_format.requires_via_extras('submissions'): submissions = self.ctx.app.db.fetch_questionnaire_submissions( questionnaire_uuid=self.safe_doc.questionnaire_uuid, @@ -203,16 +204,17 @@ def store_document(self): SentryReporter.set_tags(phase='store') s3_id = self.ctx.app.s3.identification final_file = self.safe_final_file - LOG.info(f'Preparing S3 bucket {s3_id}') + LOG.info('Preparing S3 bucket %s', s3_id) self.ctx.app.s3.ensure_bucket() - LOG.info(f'Storing document to S3 bucket {s3_id}') + LOG.info('Storing document to S3 bucket %s', s3_id) self.ctx.app.s3.store_document( tenant_uuid=self.tenant_uuid, file_name=self.doc_uuid, content_type=final_file.object_content_type, data=final_file.content, ) - LOG.info(f'Document {self.doc_uuid} stored in S3 bucket {s3_id}') + LOG.info('Document %s stored in S3 bucket %s', + self.doc_uuid, s3_id) @handle_job_step('Failed to finalize document generation') def finalize(self): @@ -235,7 +237,7 @@ def finalize(self): ), document_uuid=self.doc_uuid, ) - LOG.info(f'Document {self.doc_uuid} record finalized') + LOG.info('Document %s record finalized', self.doc_uuid) def set_job_state(self, state: str, message: str) -> bool: return self.ctx.app.db.update_document_state( @@ -249,7 +251,8 @@ def try_set_job_state(self, state: str, message: str) -> bool: return self.set_job_state(state, message) except Exception as e: SentryReporter.capture_exception(e) - LOG.warning(f'Tried to set state of {self.doc_uuid} to {state} but failed: {e}') + LOG.warning('Tried to set state of %s to %s but failed: %s', + self.doc_uuid, state, str(e)) return False def _run(self): @@ -264,9 +267,9 @@ def _run(self): def _set_failed(self, message: str): if self.try_set_job_state(DocumentState.FAILED, message): - LOG.info(f'Set state to {DocumentState.FAILED}') + LOG.info('Set state to FAILED') else: - msg = f'Could not set state to {DocumentState.FAILED}' + msg = 'Could not set state to FAILED' SentryReporter.capture_message(msg) LOG.error(msg) raise RuntimeError(msg) @@ -295,7 +298,7 @@ class DocumentWorker(CommandWorker): def __init__(self, config: DocumentWorkerConfig, workdir: pathlib.Path): self.config = config self._init_context(workdir=workdir) - self.current_job = None # type: Job | None + self.current_job: Job | None = None def _init_context(self, workdir: pathlib.Path): Context.initialize( @@ -319,10 +322,11 @@ def _init_sentry(self): ) def filter_templates(event: sentry.Event, hint: sentry.Hint) -> sentry.Event | None: - LOG.debug(f'Filtering Sentry event (template, {event.get("event_id")}, {hint})') + LOG.debug('Filtering Sentry event (template, %s, %s)', + event.get('event_id'), hint) template = event.get('tags', {}).get('template') phase = event.get('tags', {}).get('phase') - if (phase == 'render' or phase == 'prepare') and template is not None: + if phase in ('render', 'prepare') and template is not None: template_config = Context.get().app.cfg.templates.get_config(template) if template_config is not None and not template_config.send_sentry: return None @@ -333,8 +337,8 @@ def filter_templates(event: sentry.Event, hint: sentry.Hint) -> sentry.Event | N @staticmethod def _update_component_info(): built_at = dateutil.parser.parse(BUILD_INFO.built_at) - LOG.info(f'Updating component info ({BUILD_INFO.version}, ' - f'{built_at.isoformat(timespec="seconds")})') + LOG.info('Updating component info (%s, %s)', + BUILD_INFO.version, built_at.isoformat(timespec="seconds")) Context.get().app.db.update_component_info( name=COMPONENT_NAME, version=BUILD_INFO.version, @@ -367,18 +371,18 @@ def run_once(self): queue = self._run_preparation() queue.run_once() - def work(self, cmd: PersistentCommand): - document_uuid = cmd.body['document']['uuid'] - Context.get().update_trace_id(cmd.uuid) + def work(self, command: PersistentCommand): + document_uuid = command.body['document']['uuid'] + Context.get().update_trace_id(command.uuid) Context.get().update_document_id(document_uuid) SentryReporter.set_tags( - command_uuid=cmd.uuid, - tenant_uuid=cmd.tenant_uuid, + command_uuid=command.uuid, + tenant_uuid=command.tenant_uuid, document_uuid=document_uuid, phase='init', ) - LOG.info(f'Running job #{cmd.uuid}') - self.current_job = Job(command=cmd, document_uuid=document_uuid) + LOG.info('Running job #%s', command.uuid) + self.current_job = Job(command=command, document_uuid=document_uuid) self.current_job.run() self.current_job = None SentryReporter.set_tags( diff --git a/packages/dsw-document-worker/pyproject.toml b/packages/dsw-document-worker/pyproject.toml index b0760c00..ae6d6cd0 100644 --- a/packages/dsw-document-worker/pyproject.toml +++ b/packages/dsw-document-worker/pyproject.toml @@ -17,11 +17,11 @@ classifiers = [ 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Text Processing', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'click', 'Jinja2', diff --git a/packages/dsw-mailer/dsw/mailer/cli.py b/packages/dsw-mailer/dsw/mailer/cli.py index 52690b8a..e56ec244 100644 --- a/packages/dsw-mailer/dsw/mailer/cli.py +++ b/packages/dsw-mailer/dsw/mailer/cli.py @@ -1,14 +1,16 @@ -import click # type: ignore import json import pathlib +import typing +import sys -from typing import IO +import click from dsw.config.parser import MissingConfigurationError from .config import MailerConfig, MailerConfigParser from .consts import (VERSION, VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH, DEFAULT_ENCODING) +from .mailer import Mailer from .model import MessageRequest @@ -16,7 +18,7 @@ def load_config_str(config_str: str) -> MailerConfig: parser = MailerConfigParser() if not parser.can_read(config_str): click.echo('Error: Cannot parse config file', err=True) - exit(1) + sys.exit(1) try: parser.read_string(config_str) @@ -25,14 +27,15 @@ def load_config_str(config_str: str) -> MailerConfig: click.echo('Error: Missing configuration', err=True) for missing_item in e.missing: click.echo(f' - {missing_item}', err=True) - exit(1) + sys.exit(1) config = parser.config config.log.apply() return config -def validate_config(ctx, param, value: IO | None): +# pylint: disable-next=unused-argument +def validate_config(ctx, param, value: typing.IO | None): content = '' if value is not None: content = value.read() @@ -40,14 +43,15 @@ def validate_config(ctx, param, value: IO | None): return load_config_str(content) -def extract_message_request(ctx, param, value: IO): +# pylint: disable-next=unused-argument +def extract_message_request(ctx, param, value: typing.IO): data = json.load(value) try: return MessageRequest.load_from_file(data) except Exception as e: click.echo('Error: Cannot parse message request', err=True) click.echo(f'{type(e).__name__}: {str(e)}') - exit(1) + sys.exit(1) @click.group(name='dsw-mailer', help='Mailer for sending emails from DSW') @@ -60,7 +64,6 @@ def extract_message_request(ctx, param, value: IO): type=click.Path(dir_okay=True, exists=True)) def cli(ctx, config: MailerConfig, workdir: str): path_workdir = pathlib.Path(workdir) - from .mailer import Mailer config.log.apply() ctx.obj['mailer'] = Mailer(config, path_workdir) @@ -73,18 +76,17 @@ def cli(ctx, config: MailerConfig, workdir: str): required=False, callback=validate_config, type=click.File('r', encoding=DEFAULT_ENCODING)) def send(ctx, msg_request: MessageRequest, config: MailerConfig): - from .mailer import Mailer - mailer = ctx.obj['mailer'] # type: Mailer + mailer: Mailer = ctx.obj['mailer'] mailer.send(rq=msg_request, cfg=config.mail) @cli.command(name='run', help='Run mailer worker processing message jobs.') @click.pass_context def run(ctx): - from .mailer import Mailer - mailer = ctx.obj['mailer'] # type: Mailer + mailer: Mailer = ctx.obj['mailer'] mailer.run() def main(): + # pylint: disable-next=no-value-for-parameter cli(obj={}) diff --git a/packages/dsw-mailer/dsw/mailer/config.py b/packages/dsw-mailer/dsw/mailer/config.py index 50045690..1002d308 100644 --- a/packages/dsw-mailer/dsw/mailer/config.py +++ b/packages/dsw-mailer/dsw/mailer/config.py @@ -1,3 +1,4 @@ +import dataclasses import enum import pathlib @@ -10,6 +11,7 @@ from dsw.database.model import DBInstanceConfigMail +# pylint: disable-next=too-few-public-methods class _ExperimentalKeys(ConfigKeysContainer): job_timeout = ConfigKey( yaml_path=['experimental', 'jobTimeout'], @@ -19,12 +21,12 @@ class _ExperimentalKeys(ConfigKeysContainer): ) +@dataclasses.dataclass class ExperimentalConfig(ConfigModel): - - def __init__(self, job_timeout: int | None): - self.job_timeout = job_timeout + job_timeout: int | None +# pylint: disable-next=too-few-public-methods class _MailKeys(ConfigKeysContainer): enabled = ConfigKey( yaml_path=['mail', 'enabled'], @@ -76,6 +78,7 @@ class _MailKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class _MailLegacySMTPKeys(ConfigKeysContainer): host = ConfigKey( yaml_path=['mail', 'host'], @@ -121,6 +124,7 @@ class _MailLegacySMTPKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class _MailSMTPKeys(ConfigKeysContainer): host = ConfigKey( yaml_path=['mail', 'smtp', 'host'], @@ -156,6 +160,7 @@ class _MailSMTPKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class _MailAmazonSESKeys(ConfigKeysContainer): access_key_id = ConfigKey( yaml_path=['mail', 'amazonSes', 'accessKeyId'], @@ -174,6 +179,7 @@ class _MailAmazonSESKeys(ConfigKeysContainer): ) +# pylint: disable-next=too-few-public-methods class MailerConfigKeys(ConfigKeys): mail = _MailKeys mail_legacy_smtp = _MailLegacySMTPKeys @@ -204,7 +210,7 @@ def has(cls, name): class MailSMTPConfig: - def __init__(self, host: str | None = None, port: int | None = None, + def __init__(self, *, host: str | None = None, port: int | None = None, security: str | None = None, ssl: bool | None = None, username: str | None = None, password: str | None = None, auth_enabled: bool | None = None, timeout: int = 10): @@ -253,14 +259,11 @@ def has_credentials(self) -> bool: return self.username is not None and self.password is not None +@dataclasses.dataclass class MailAmazonSESConfig: - - def __init__(self, access_key_id: str | None = None, - secret_access_key: str | None = None, - region: str | None = None): - self.access_key_id = access_key_id - self.secret_access_key = secret_access_key - self.region = region + access_key_id: str | None = None + secret_access_key: str | None = None + region: str | None = None def has_credentials(self) -> bool: return self.access_key_id is not None and self.secret_access_key is not None @@ -268,7 +271,8 @@ def has_credentials(self) -> bool: class MailConfig(ConfigModel): - def __init__(self, enabled: bool, name: str, email: str, + # pylint: disable-next=too-many-arguments + def __init__(self, *, enabled: bool, name: str, email: str, provider: str, smtp: MailSMTPConfig, amazon_ses: MailAmazonSESConfig, rate_limit_window: int, rate_limit_count: int, dkim_selector: str | None = None, dkim_privkey_file: str | None = None): @@ -322,7 +326,7 @@ def __str__(self): class MailerConfig: - def __init__(self, db: DatabaseConfig, log: LoggingConfig, + def __init__(self, *, db: DatabaseConfig, log: LoggingConfig, mail: MailConfig, sentry: SentryConfig, general: GeneralConfig, aws: AWSConfig, experimental: ExperimentalConfig): diff --git a/packages/dsw-mailer/dsw/mailer/context.py b/packages/dsw-mailer/dsw/mailer/context.py index e7c6aac0..68c97b8e 100644 --- a/packages/dsw-mailer/dsw/mailer/context.py +++ b/packages/dsw-mailer/dsw/mailer/context.py @@ -1,13 +1,11 @@ +import dataclasses import pathlib -from typing import TYPE_CHECKING +from dsw.database import Database +from .config import MailerConfig from .templates import TemplateRegistry -if TYPE_CHECKING: - from .config import MailerConfig - from dsw.database import Database - class ContextNotInitializedError(RuntimeError): @@ -15,18 +13,16 @@ def __init__(self): super().__init__('Context cannot be retrieved, not initialized') +@dataclasses.dataclass class AppContext: - - def __init__(self, db, cfg, workdir): - self.db = db # type: Database - self.cfg = cfg # type: MailerConfig - self.workdir = workdir # type: pathlib.Path + db: Database + cfg: MailerConfig + workdir: pathlib.Path +@dataclasses.dataclass class JobContext: - - def __init__(self, trace_id: str): - self.trace_id = trace_id + trace_id: str class _Context: @@ -47,7 +43,7 @@ def reset_ids(self): class Context: - _instance = None # type: _Context | None + _instance: _Context | None = None @classmethod def get(cls) -> _Context: diff --git a/packages/dsw-mailer/dsw/mailer/handlers.py b/packages/dsw-mailer/dsw/mailer/handlers.py index 9dffd5b8..22e13ae6 100644 --- a/packages/dsw-mailer/dsw/mailer/handlers.py +++ b/packages/dsw-mailer/dsw/mailer/handlers.py @@ -2,14 +2,15 @@ import pathlib from .cli import load_config_str -from .consts import VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH +from .consts import VAR_APP_CONFIG_PATH, VAR_WORKDIR_PATH, DEFAULT_ENCODING from .mailer import Mailer +# pylint: disable-next=unused-argument def lambda_handler(event, context): config_path = pathlib.Path(os.getenv(VAR_APP_CONFIG_PATH, '/var/task/application.yml')) workdir_path = pathlib.Path(os.getenv(VAR_WORKDIR_PATH, '/var/task/templates')) - config = load_config_str(config_path.read_text()) + config = load_config_str(config_path.read_text(encoding=DEFAULT_ENCODING)) mailer = Mailer(config, workdir_path) mailer.run_once() diff --git a/packages/dsw-mailer/dsw/mailer/mailer.py b/packages/dsw-mailer/dsw/mailer/mailer.py index 34fa375c..b690360c 100644 --- a/packages/dsw-mailer/dsw/mailer/mailer.py +++ b/packages/dsw-mailer/dsw/mailer/mailer.py @@ -1,11 +1,12 @@ import datetime -import dateutil.parser import logging import math import pathlib import time import urllib.parse +import dateutil.parser + from dsw.command_queue import CommandWorker, CommandQueue from dsw.config.sentry import SentryReporter from dsw.database.database import Database @@ -52,8 +53,8 @@ def _init_context(self, workdir: pathlib.Path): @staticmethod def _update_component_info(): built_at = dateutil.parser.parse(BUILD_INFO.built_at) - LOG.info(f'Updating component info ({BUILD_INFO.version}, ' - f'{built_at.isoformat(timespec="seconds")})') + LOG.info('Updating component info (%s, %s)', + BUILD_INFO.version, built_at.isoformat(timespec="seconds")) Context.get().app.db.update_component_info( name=COMPONENT_NAME, version=BUILD_INFO.version, @@ -86,34 +87,34 @@ def run_once(self): queue = self._run_preparation() queue.run_once() - def work(self, cmd: PersistentCommand): + def work(self, command: PersistentCommand): # update Sentry info SentryReporter.set_tags( template='?', - command_uuid=cmd.uuid, - tenant_uuid=cmd.tenant_uuid, + command_uuid=command.uuid, + tenant_uuid=command.tenant_uuid, ) - Context.get().update_trace_id(cmd.uuid) + Context.get().update_trace_id(command.uuid) # work app_ctx = Context.get().app - mc = MailerCommand.load(cmd) + mc = MailerCommand.load(command) rq = mc.to_request( - msg_id=cmd.uuid, + msg_id=command.uuid, trigger='PersistentComment', ) # get tenant config from DB - tenant_cfg = app_ctx.db.get_tenant_config(tenant_uuid=cmd.tenant_uuid) - LOG.debug(f'Tenant config from DB: {tenant_cfg}') + tenant_cfg = app_ctx.db.get_tenant_config(tenant_uuid=command.tenant_uuid) + LOG.debug('Tenant config from DB: %s', tenant_cfg) if tenant_cfg is not None: rq.style.from_dict(tenant_cfg.look_and_feel) # get mailer config from DB mail_cfg = merge_mail_configs( cfg=self.cfg, - db_cfg=app_ctx.db.get_mail_config(tenant_uuid=cmd.tenant_uuid), + db_cfg=app_ctx.db.get_mail_config(tenant_uuid=command.tenant_uuid), ) - LOG.debug(f'Mail config from DB: {mail_cfg}') + LOG.debug('Mail config from DB: %s', mail_cfg) # client URL - rq.client_url = cmd.body.get('clientUrl', app_ctx.cfg.general.client_url) + rq.client_url = command.body.get('clientUrl', app_ctx.cfg.general.client_url) rq.domain = urllib.parse.urlparse(rq.client_url).hostname # update Sentry info SentryReporter.set_tags(template=rq.template_name) @@ -134,15 +135,15 @@ def process_exception(self, e: BaseException): SentryReporter.capture_exception(e) def send(self, rq: MessageRequest, cfg: MailConfig): - LOG.info(f'Sending request: {rq.template_name} ({rq.id})') + LOG.info('Sending request: %s (%s)', rq.template_name, rq.id) # get template if not self.ctx.templates.has_template_for(rq): raise RuntimeError(f'Template not found: {rq.template_name}') # render - LOG.info(f'Rendering message: {rq.template_name}') + LOG.info('Rendering message: %s', rq.template_name) msg = self.ctx.templates.render(rq, cfg) # send - LOG.info(f'Sending message: {rq.template_name}') + LOG.info('Sending message: %s', rq.template_name) send(msg, cfg) LOG.info('Message sent successfully') @@ -152,7 +153,7 @@ class RateLimiter: def __init__(self, window: int, count: int): self.window = window self.count = count - self.hits = [] # type: list[float] + self.hits: list[float] = [] def hit(self): if self.window == 0: @@ -168,14 +169,14 @@ def hit(self): LOG.info('Reached rate limit') sleep_time = math.ceil(self.window - now + self.hits[0]) if sleep_time > 1: - LOG.info(f'Will sleep now for {sleep_time} second') + LOG.info('Will sleep now for %s seconds', sleep_time) time.sleep(sleep_time) class MailerCommand: - def __init__(self, recipients: list[str], mode: str, template: str, ctx: dict, - tenant_uuid: str, cmd_uuid: str): + def __init__(self, *, recipients: list[str], mode: str, template: str, + ctx: dict, tenant_uuid: str, cmd_uuid: str): self.mode = mode self.template = template self.recipients = recipients @@ -203,19 +204,19 @@ def _enrich_context(self): } @staticmethod - def load(cmd: PersistentCommand) -> 'MailerCommand': - if cmd.component != CMD_COMPONENT: + def load(command: PersistentCommand) -> 'MailerCommand': + if command.component != CMD_COMPONENT: raise RuntimeError('Tried to process non-mailer command') - if cmd.function != CMD_FUNCTION: - raise RuntimeError(f'Unsupported function: {cmd.function}') + if command.function != CMD_FUNCTION: + raise RuntimeError(f'Unsupported function: {command.function}') try: return MailerCommand( - mode=cmd.body['mode'], - template=cmd.body['template'], - recipients=cmd.body['recipients'], - ctx=cmd.body['parameters'], - tenant_uuid=cmd.tenant_uuid, - cmd_uuid=cmd.uuid, + mode=command.body['mode'], + template=command.body['template'], + recipients=command.body['recipients'], + ctx=command.body['parameters'], + tenant_uuid=command.tenant_uuid, + cmd_uuid=command.uuid, ) except KeyError as e: - raise RuntimeError(f'Cannot parse command: {str(e)}') + raise RuntimeError(f'Cannot parse command: {str(e)}') from e diff --git a/packages/dsw-mailer/dsw/mailer/model.py b/packages/dsw-mailer/dsw/mailer/model.py index 722f419e..7027d9b5 100644 --- a/packages/dsw-mailer/dsw/mailer/model.py +++ b/packages/dsw-mailer/dsw/mailer/model.py @@ -1,3 +1,4 @@ +import dataclasses import os import re @@ -14,8 +15,7 @@ def contrast_ratio(color1: 'Color', color2: 'Color') -> float: l2 = color2.luminance + 0.05 if l1 > l2: return l1 / l2 - else: - return l2 / l1 + return l2 / l1 def __init__(self, color_hex: str = '#000000', default: str = '#000000'): color_hex = self.parse_color_to_hex(color_hex) or default @@ -46,8 +46,7 @@ def _luminance_component(component: int): c = component / 255 if c <= 0.03928: return c / 12.92 - else: - return ((c + 0.055) / 1.055) ** 2.4 + return ((c + 0.055) / 1.055) ** 2.4 r = _luminance_component(self.red) g = _luminance_component(self.green) @@ -66,8 +65,7 @@ def is_light(self): def contrast_color(self) -> 'Color': if self.contrast_ratio(self, Color('#ffffff')) > 3: return Color('#ffffff') - else: - return Color('#000000') + return Color('#000000') def __str__(self): return self.hex @@ -83,7 +81,7 @@ def __init__(self, logo_url: str | None, primary_color: str, self.illustrations_color = Color(illustrations_color, Color.DEFAULT_ILLUSTRATIONS_HEX) def from_dict(self, data: dict | None): - data = data or dict() + data = data or {} if data.get('logoUrl', None) is not None: self.logo_url = data.get('logoUrl') if data.get('primaryColor', None) is not None: @@ -134,7 +132,7 @@ def __init__(self, part_type: str, file: str): self.content_type = '' self.encoding = '' - def _update_from_data(self, data: dict): + def update_from_data(self, data: dict): for field in self.FIELDS: target_field = field.replace('-', '_') if field in data.keys(): @@ -150,13 +148,13 @@ def load_from_file(data: dict) -> 'TemplateDescriptorPart': part_type=data.get('type', 'unknown'), file=data.get('file', ''), ) - part._update_from_data(data) + part.update_from_data(data) return part class TemplateDescriptor: - def __init__(self, message_id: str, subject: str, subject_prefix: bool, + def __init__(self, *, message_id: str, subject: str, subject_prefix: bool, default_sender_name: str | None, language: str, importance: str, sensitivity: str | None, priority: str | None): @@ -191,7 +189,7 @@ def load_from_file(data: dict) -> 'TemplateDescriptor': class MessageRequest: - def __init__(self, message_id: str, template_name: str, trigger: str, + def __init__(self, *, message_id: str, template_name: str, trigger: str, ctx: dict, recipients: list[str], style: StyleConfig | None = None): self.id = message_id self.template_name = template_name @@ -213,35 +211,35 @@ def load_from_file(data: dict) -> 'MessageRequest': recipients=data.get('recipients', []), style=StyleConfig( logo_url=data.get('styleLogoUrl', None), - primary_color=data.get('stylePrimaryColor', Color.DEFAULT_PRIMARY_HEX), - illustrations_color=data.get('styleIllustrationsColor', Color.DEFAULT_ILLUSTRATIONS_HEX), + primary_color=data.get('stylePrimaryColor', + Color.DEFAULT_PRIMARY_HEX), + illustrations_color=data.get('styleIllustrationsColor', + Color.DEFAULT_ILLUSTRATIONS_HEX), ), ) +@dataclasses.dataclass class MailMessage: - - def __init__(self): - self.from_mail = '' # type: str - self.from_name = None # type: str | None - self.recipients = list() # type: list[str] - self.subject = '' # type: str - self.plain_body = None # type: str | None - self.html_body = None # type: str | None - self.html_images = list() # type: list[MailAttachment] - self.attachments = list() # type: list[MailAttachment] - self.msg_id = None # type: str | None - self.msg_domain = None # type: str | None - self.language = 'en' # type: str - self.importance = 'normal' # type: str - self.sensitivity = None # type: str | None - self.priority = None # type: str | None - self.client_url = '' # type: str - - + from_mail: str = '' + from_name: str | None = None + recipients: list[str] = dataclasses.field(default_factory=list) + subject: str = '' + plain_body: str | None = None + html_body: str | None = None + html_images: list['MailAttachment'] = dataclasses.field(default_factory=list) + attachments: list['MailAttachment'] = dataclasses.field(default_factory=list) + msg_id: str | None = None + msg_domain: str | None = None + language: str = 'en' + importance: str = 'normal' + sensitivity: str | None = None + priority: str | None = None + client_url: str = '' + + +@dataclasses.dataclass class MailAttachment: - - def __init__(self, name='', content_type='', data=''): - self.name = name - self.content_type = content_type - self.data = data + name: str + content_type: str + data: bytes diff --git a/packages/dsw-mailer/dsw/mailer/sender/amazon_ses.py b/packages/dsw-mailer/dsw/mailer/sender/amazon_ses.py index 647ac2ae..ae384660 100644 --- a/packages/dsw-mailer/dsw/mailer/sender/amazon_ses.py +++ b/packages/dsw-mailer/dsw/mailer/sender/amazon_ses.py @@ -1,6 +1,7 @@ -import boto3 import logging +import boto3 + from .base import BaseMailSender from ..config import MailConfig from ..model import MailMessage @@ -19,7 +20,8 @@ def validate_config(cfg: MailConfig): raise ValueError('Missing region for Amazon SES') def send(self, message: MailMessage): - LOG.info(f'Sending via Amazon SES (region {self.cfg.amazon_ses.region})') + LOG.info('Sending via Amazon SES (region %s)', + self.cfg.amazon_ses.region) self._send(message, self.cfg) def _send(self, mail: MailMessage, cfg: MailConfig): diff --git a/packages/dsw-mailer/dsw/mailer/sender/base.py b/packages/dsw-mailer/dsw/mailer/sender/base.py index 4ff86200..7cf5caec 100644 --- a/packages/dsw-mailer/dsw/mailer/sender/base.py +++ b/packages/dsw-mailer/dsw/mailer/sender/base.py @@ -1,8 +1,6 @@ import abc import datetime -import dkim import logging -import pathvalidate from email import encoders from email.mime.base import MIMEBase @@ -10,6 +8,8 @@ from email.mime.text import MIMEText from email.utils import formataddr, format_datetime, make_msgid +import pathvalidate + from ..config import MailConfig from ..consts import DEFAULT_ENCODING from ..model import MailMessage, MailAttachment @@ -71,6 +71,9 @@ def add_header(name: str, value: str): add_header('Priority', mail.priority) if self.cfg.dkim_selector and self.cfg.dkim_privkey: + # pylint: disable=import-outside-toplevel + import dkim # type: ignore + sender_domain = mail.from_mail.split('@')[-1] signature = dkim.sign( message=msg.as_bytes(), @@ -147,4 +150,3 @@ def validate_config(cfg: MailConfig): def send(self, message: MailMessage): LOG.info('No provider configured, not sending anything') - return diff --git a/packages/dsw-mailer/dsw/mailer/sender/smtp.py b/packages/dsw-mailer/dsw/mailer/sender/smtp.py index 0dc6a477..6a1082eb 100644 --- a/packages/dsw-mailer/dsw/mailer/sender/smtp.py +++ b/packages/dsw-mailer/dsw/mailer/sender/smtp.py @@ -1,10 +1,11 @@ import logging import smtplib import ssl -import tenacity from email.utils import formataddr +import tenacity + from .base import BaseMailSender from ..config import MailConfig from ..model import MailMessage @@ -32,7 +33,8 @@ def validate_config(cfg: MailConfig): after=tenacity.after_log(LOG, logging.DEBUG), ) def send(self, message: MailMessage): - LOG.info(f'Sending via SMTP (server {self.cfg.smtp.host}:{self.cfg.smtp.port})') + LOG.info('Sending via SMTP (server %s:%s)', + self.cfg.smtp.host, self.cfg.smtp.port) if self.cfg.smtp.is_ssl: self._send_smtp_ssl(mail=message) else: diff --git a/packages/dsw-mailer/dsw/mailer/templates.py b/packages/dsw-mailer/dsw/mailer/templates.py index 8f1858f2..f3915ffa 100644 --- a/packages/dsw-mailer/dsw/mailer/templates.py +++ b/packages/dsw-mailer/dsw/mailer/templates.py @@ -1,17 +1,18 @@ import datetime +import json +import logging +import pathlib +import re + import dateutil.parser import jinja2 import jinja2.sandbox -import json -import logging import markdown import markupsafe -import pathlib -import re from .config import MailerConfig, MailConfig from .consts import DEFAULT_ENCODING -from .model import MailMessage, MailAttachment, MessageRequest,\ +from .model import MailMessage, MailAttachment, MessageRequest, \ TemplateDescriptor, TemplateDescriptorPart @@ -27,8 +28,8 @@ def __init__(self, name: str, descriptor: TemplateDescriptor, self.descriptor = descriptor self.html_template = html_template self.plain_template = plain_template - self.attachments = list() # type: list[MailAttachment] - self.html_images = list() # type: list[MailAttachment] + self.attachments: list[MailAttachment] = [] + self.html_images: list[MailAttachment] = [] def render(self, rq: MessageRequest, mail_name: str | None, mail_from: str) -> MailMessage: ctx = rq.ctx @@ -71,7 +72,7 @@ def __init__(self, cfg: MailerConfig, workdir: pathlib.Path): loader=jinja2.FileSystemLoader(searchpath=workdir), extensions=['jinja2.ext.do'], ) - self.templates = dict() # type: dict[str, MailTemplate] + self.templates: dict[str, MailTemplate] = {} self._set_filters() self._load_templates() @@ -110,16 +111,16 @@ def _load_descriptor(path: pathlib.Path) -> TemplateDescriptor | None: data = json.loads(path.read_text(encoding=DEFAULT_ENCODING)) return TemplateDescriptor.load_from_file(data) except Exception as e: - LOG.warning(f'Cannot load template descriptor at {str(path)}' - f'due to: {str(e)}') + LOG.warning('Cannot load template descriptor at %s: %s', + path.as_posix(), str(e)) return None def _load_template(self, path: pathlib.Path, descriptor: TemplateDescriptor) -> MailTemplate | None: html_template = None plain_template = None - attachments = list() - html_images = list() + attachments = [] + html_images = [] for part in descriptor.parts: if part.type == 'html': html_template = self._load_jinja2(path / part.file) @@ -130,8 +131,8 @@ def _load_template(self, path: pathlib.Path, elif part.type == 'html_image': html_images.append(self._load_attachment(path, part)) if html_template is None and plain_template is None: - LOG.warning(f'Template "{descriptor.id}" from {str(path)}' - f'does not have HTML nor Plain part - skipping') + LOG.warning('Template "%s" from %s has no HTML nor Plain part - skipping', + descriptor.id, path.as_posix()) return None template = MailTemplate( name=path.name, @@ -152,11 +153,12 @@ def _load_templates(self): template = self._load_template(path, descriptor) if template is None: continue - LOG.info(f'Loaded template "{descriptor.id}" from {str(path)}') + LOG.info('Loaded template "%s" from %s', + descriptor.id, path.as_posix()) self.templates[descriptor.id] = template def has_template_for(self, rq: MessageRequest) -> bool: - return rq.template_name in self.templates.keys() + return rq.template_name in self.templates def render(self, rq: MessageRequest, cfg: MailConfig) -> MailMessage: used_cfg = cfg or self.cfg.mail @@ -183,9 +185,10 @@ def extendMarkdown(self, md): class DSWMarkdownProcessor(markdown.preprocessors.Preprocessor): + LI_RE = re.compile(r'^[ ]*((\d+\.)|[*+-])[ ]+.*') + def __init__(self, md): super().__init__(md) - self.LI_RE = re.compile(r'^[ ]*((\d+\.)|[*+-])[ ]+.*') def run(self, lines): prev_li = False diff --git a/packages/dsw-mailer/pyproject.toml b/packages/dsw-mailer/pyproject.toml index cc9b8fb4..75d578e3 100644 --- a/packages/dsw-mailer/pyproject.toml +++ b/packages/dsw-mailer/pyproject.toml @@ -16,12 +16,12 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Communications :: Email', 'Topic :: Text Processing', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'boto3', 'click', diff --git a/packages/dsw-models/dsw/models/io.py b/packages/dsw-models/dsw/models/io.py index d6402cef..23e652db 100644 --- a/packages/dsw-models/dsw/models/io.py +++ b/packages/dsw-models/dsw/models/io.py @@ -1,10 +1,10 @@ import json +import typing class DSWJSONEncoder(json.JSONEncoder): - def default(self, value): - if hasattr(value, 'to_dict') and callable(value.to_dict): - return value.to_dict() - else: - return super().default(value) + def default(self, o: typing.Any) -> typing.Any: + if hasattr(o, 'to_dict') and callable(o.to_dict): + return o.to_dict() + return super().default(o) diff --git a/packages/dsw-models/dsw/models/km/events.py b/packages/dsw-models/dsw/models/km/events.py index ca83c13f..9ce3fe3a 100644 --- a/packages/dsw-models/dsw/models/km/events.py +++ b/packages/dsw-models/dsw/models/km/events.py @@ -1,15 +1,15 @@ +# pylint: disable=too-many-arguments, too-many-locals, too-many-lines import abc - -from typing import Generic, Optional, TypeVar, Any +import typing # https://github.com/ds-wizard/engine-backend/blob/develop/engine-shared/src/Shared/Model/Event/ -T = TypeVar('T') +T = typing.TypeVar('T') class MetricMeasure: - def __init__(self, metric_uuid: str, measure: float, weight: float): + def __init__(self, *, metric_uuid: str, measure: float, weight: float): self.metric_uuid = metric_uuid self.measure = measure self.weight = weight @@ -30,9 +30,9 @@ def from_dict(data: dict) -> 'MetricMeasure': ) -class EventField(Generic[T]): +class EventField(typing.Generic[T]): - def __init__(self, changed: bool, value: Optional[T]): + def __init__(self, *, changed: bool, value: T | None): self.changed = changed self.value = value @@ -59,7 +59,7 @@ def from_dict(data: dict, loader=None) -> 'EventField': class MapEntry: - def __init__(self, key: str, value: str): + def __init__(self, *, key: str, value: str): self.key = key self.value = value @@ -84,7 +84,7 @@ class _KMEvent(abc.ABC): TYPE = 'UNKNOWN' METAMODEL_VERSION = 14 - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str): self.event_uuid = event_uuid self.entity_uuid = entity_uuid @@ -108,7 +108,7 @@ def from_dict(cls, data: dict): class _KMAddEvent(_KMEvent, abc.ABC): TYPE = 'ADD' - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry]): super().__init__( event_uuid=event_uuid, @@ -129,7 +129,7 @@ def to_dict(self) -> dict: class _KMEditEvent(_KMEvent, abc.ABC): TYPE = 'EDIT' - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]]): super().__init__( event_uuid=event_uuid, @@ -154,7 +154,7 @@ class _KMDeleteEvent(_KMEvent, abc.ABC): class _KMMoveEvent(_KMEvent, abc.ABC): TYPE = 'MOVE' - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, target_uuid: str): super().__init__( event_uuid=event_uuid, @@ -172,7 +172,7 @@ def to_dict(self) -> dict: return result -EVENT_TYPES = {} # type: dict[str, Any] +EVENT_TYPES: dict[str, typing.Any] = {} def event_class(cls): @@ -206,7 +206,7 @@ def from_dict(cls, data: dict) -> 'AddKnowledgeModelEvent': @event_class class EditKnowledgeModelEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], chapter_uuids: EventField[list[str]], tag_uuids: EventField[list[str]], integration_uuids: EventField[list[str]], metric_uuids: EventField[list[str]], @@ -260,9 +260,9 @@ def from_dict(cls, data: dict) -> 'EditKnowledgeModelEvent': @event_class class AddChapterEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str]): + title: str, text: str | None): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -300,9 +300,9 @@ def from_dict(cls, data: dict) -> 'AddChapterEvent': @event_class class EditChapterEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], question_uuids: EventField[list[str]]): super().__init__( event_uuid=event_uuid, @@ -369,9 +369,9 @@ def from_dict(cls, data: dict) -> 'DeleteChapterEvent': @event_class class AddQuestionEvent(_KMAddEvent, abc.ABC): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str], required_phase_uuid: Optional[str], + title: str, text: str | None, required_phase_uuid: str | None, tag_uuids: list[str]): super().__init__( event_uuid=event_uuid, @@ -416,22 +416,6 @@ def from_dict(cls, data: dict) -> 'AddQuestionEvent': class AddOptionsQuestionEvent(AddQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, - created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str], required_phase_uuid: Optional[str], - tag_uuids: list[str]): - super().__init__( - event_uuid=event_uuid, - entity_uuid=entity_uuid, - parent_uuid=parent_uuid, - created_at=created_at, - annotations=annotations, - title=title, - text=text, - required_phase_uuid=required_phase_uuid, - tag_uuids=tag_uuids, - ) - def to_dict(self) -> dict: result = super().to_dict() result.update({ @@ -460,22 +444,6 @@ def from_dict(cls, data: dict) -> 'AddOptionsQuestionEvent': class AddMultiChoiceQuestionEvent(AddQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, - created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str], required_phase_uuid: Optional[str], - tag_uuids: list[str]): - super().__init__( - event_uuid=event_uuid, - entity_uuid=entity_uuid, - parent_uuid=parent_uuid, - created_at=created_at, - annotations=annotations, - title=title, - text=text, - required_phase_uuid=required_phase_uuid, - tag_uuids=tag_uuids, - ) - def to_dict(self) -> dict: result = super().to_dict() result.update({ @@ -504,22 +472,6 @@ def from_dict(cls, data: dict) -> 'AddMultiChoiceQuestionEvent': class AddListQuestionEvent(AddQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, - created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str], required_phase_uuid: Optional[str], - tag_uuids: list[str]): - super().__init__( - event_uuid=event_uuid, - entity_uuid=entity_uuid, - parent_uuid=parent_uuid, - created_at=created_at, - annotations=annotations, - title=title, - text=text, - required_phase_uuid=required_phase_uuid, - tag_uuids=tag_uuids, - ) - def to_dict(self) -> dict: result = super().to_dict() result.update({ @@ -548,9 +500,9 @@ def from_dict(cls, data: dict) -> 'AddListQuestionEvent': class AddValueQuestionEvent(AddQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str], required_phase_uuid: Optional[str], + title: str, text: str | None, required_phase_uuid: str | None, tag_uuids: list[str], value_type: str): super().__init__( event_uuid=event_uuid, @@ -595,9 +547,9 @@ def from_dict(cls, data: dict) -> 'AddValueQuestionEvent': class AddIntegrationQuestionEvent(AddQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - title: str, text: Optional[str], required_phase_uuid: Optional[str], + title: str, text: str | None, required_phase_uuid: str | None, tag_uuids: list[str], integration_uuid: str, props: dict[str, str]): super().__init__( event_uuid=event_uuid, @@ -646,10 +598,10 @@ def from_dict(cls, data: dict) -> 'AddIntegrationQuestionEvent': @event_class class EditQuestionEvent(_KMEditEvent, abc.ABC): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], - required_phase_uuid: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], + required_phase_uuid: EventField[str | None], tag_uuids: EventField[list[str]], expert_uuids: EventField[list[str]], reference_uuids: EventField[list[str]]): super().__init__( @@ -699,10 +651,10 @@ def from_dict(cls, data: dict) -> 'EditQuestionEvent': class EditOptionsQuestionEvent(EditQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], - required_phase_uuid: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], + required_phase_uuid: EventField[str | None], tag_uuids: EventField[list[str]], expert_uuids: EventField[list[str]], reference_uuids: EventField[list[str]], answer_uuids: EventField[list[str]]): @@ -756,10 +708,10 @@ def from_dict(cls, data: dict) -> 'EditOptionsQuestionEvent': class EditMultiChoiceQuestionEvent(EditQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], - required_phase_uuid: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], + required_phase_uuid: EventField[str | None], tag_uuids: EventField[list[str]], expert_uuids: EventField[list[str]], reference_uuids: EventField[list[str]], choice_uuids: EventField[list[str]]): @@ -813,10 +765,10 @@ def from_dict(cls, data: dict) -> 'EditMultiChoiceQuestionEvent': class EditListQuestionEvent(EditQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], - required_phase_uuid: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], + required_phase_uuid: EventField[str | None], tag_uuids: EventField[list[str]], expert_uuids: EventField[list[str]], reference_uuids: EventField[list[str]], item_template_question_uuids: EventField[list[str]]): @@ -870,10 +822,10 @@ def from_dict(cls, data: dict) -> 'EditListQuestionEvent': class EditValueQuestionEvent(EditQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], - required_phase_uuid: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], + required_phase_uuid: EventField[str | None], tag_uuids: EventField[list[str]], expert_uuids: EventField[list[str]], reference_uuids: EventField[list[str]], value_type: EventField[str]): super().__init__( @@ -926,10 +878,10 @@ def from_dict(cls, data: dict) -> 'EditValueQuestionEvent': class EditIntegrationQuestionEvent(EditQuestionEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], text: EventField[Optional[str]], - required_phase_uuid: EventField[Optional[str]], + title: EventField[str], text: EventField[str | None], + required_phase_uuid: EventField[str | None], tag_uuids: EventField[list[str]], expert_uuids: EventField[list[str]], reference_uuids: EventField[list[str]], integration_uuid: EventField[str], props: EventField[dict[str, str]]): @@ -1009,9 +961,9 @@ def from_dict(cls, data: dict) -> 'DeleteQuestionEvent': @event_class class AddAnswerEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - label: str, advice: Optional[str], metric_measures: list[MetricMeasure]): + label: str, advice: str | None, metric_measures: list[MetricMeasure]): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -1052,9 +1004,9 @@ def from_dict(cls, data: dict) -> 'AddAnswerEvent': @event_class class EditAnswerEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - label: EventField[str], advice: EventField[Optional[str]], + label: EventField[str], advice: EventField[str | None], follow_up_uuids: EventField[list[str]], metric_measures: EventField[list[MetricMeasure]]): super().__init__( @@ -1128,7 +1080,7 @@ def from_dict(cls, data: dict) -> 'DeleteAnswerEvent': @event_class class AddChoiceEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], label: str): super().__init__( @@ -1165,7 +1117,7 @@ def from_dict(cls, data: dict) -> 'AddChoiceEvent': @event_class class EditChoiceEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], label: EventField[str]): super().__init__( @@ -1227,7 +1179,7 @@ def from_dict(cls, data: dict) -> 'DeleteChoiceEvent': @event_class class AddExpertEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], name: str, email: str): super().__init__( @@ -1267,7 +1219,7 @@ def from_dict(cls, data: dict) -> 'AddExpertEvent': @event_class class EditExpertEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], name: EventField[str], email: EventField[str]): super().__init__( @@ -1332,16 +1284,6 @@ def from_dict(cls, data: dict) -> 'DeleteExpertEvent': @event_class class AddReferenceEvent(_KMAddEvent, abc.ABC): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, - created_at: str, annotations: list[MapEntry]): - super().__init__( - event_uuid=event_uuid, - entity_uuid=entity_uuid, - parent_uuid=parent_uuid, - created_at=created_at, - annotations=annotations, - ) - def to_dict(self) -> dict: result = super().to_dict() result.update({ @@ -1365,7 +1307,7 @@ def from_dict(cls, data: dict) -> 'AddReferenceEvent': class AddResourcePageReferenceEvent(AddReferenceEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], short_uuid: str): super().__init__( @@ -1403,7 +1345,7 @@ def from_dict(cls, data: dict) -> 'AddResourcePageReferenceEvent': class AddURLReferenceEvent(AddReferenceEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], url: str, label: str): super().__init__( @@ -1444,7 +1386,7 @@ def from_dict(cls, data: dict) -> 'AddURLReferenceEvent': class AddCrossReferenceEvent(AddReferenceEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], target_uuid: str, description: str): super().__init__( @@ -1486,16 +1428,6 @@ def from_dict(cls, data: dict) -> 'AddCrossReferenceEvent': @event_class class EditReferenceEvent(_KMEditEvent, abc.ABC): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, - created_at: str, annotations: EventField[list[MapEntry]]): - super().__init__( - event_uuid=event_uuid, - entity_uuid=entity_uuid, - parent_uuid=parent_uuid, - created_at=created_at, - annotations=annotations, - ) - def to_dict(self) -> dict: result = super().to_dict() result.update({ @@ -1519,7 +1451,7 @@ def from_dict(cls, data: dict) -> 'EditReferenceEvent': class EditResourcePageReferenceEvent(EditReferenceEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], short_uuid: EventField[str]): super().__init__( @@ -1560,7 +1492,7 @@ def from_dict(cls, data: dict) -> 'EditResourcePageReferenceEvent': class EditURLReferenceEvent(EditReferenceEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], url: EventField[str], label: EventField[str]): super().__init__( @@ -1604,7 +1536,7 @@ def from_dict(cls, data: dict) -> 'EditURLReferenceEvent': class EditCrossReferenceEvent(EditReferenceEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], target_uuid: EventField[str], description: EventField[str]): super().__init__( @@ -1671,9 +1603,9 @@ def from_dict(cls, data: dict) -> 'DeleteReferenceEvent': @event_class class AddTagEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - name: str, description: Optional[str], color: str): + name: str, description: str | None, color: str): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -1714,9 +1646,9 @@ def from_dict(cls, data: dict) -> 'AddTagEvent': @event_class class EditTagEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - name: EventField[str], description: EventField[Optional[str]], + name: EventField[str], description: EventField[str | None], color: EventField[str]): super().__init__( event_uuid=event_uuid, @@ -1783,10 +1715,10 @@ def from_dict(cls, data: dict) -> 'DeleteTagEvent': @event_class class AddIntegrationEvent(_KMAddEvent, abc.ABC): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - integration_id: str, name: str, props: list[str], logo: Optional[str], - item_url: Optional[str]): + integration_id: str, name: str, props: list[str], logo: str | None, + item_url: str | None): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -1826,12 +1758,12 @@ def from_dict(cls, data: dict) -> 'AddIntegrationEvent': class AddApiIntegrationEvent(AddIntegrationEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - integration_id: str, name: str, props: list[str], logo: Optional[str], - item_url: Optional[str], rq_method: str, rq_url: str, + integration_id: str, name: str, props: list[str], logo: str | None, + item_url: str | None, rq_method: str, rq_url: str, rq_headers: list[MapEntry], rq_body: str, rq_empty_search: bool, - rs_list_field: Optional[str], rs_item_id: Optional[str], + rs_list_field: str | None, rs_item_id: str | None, rs_item_template: str): super().__init__( event_uuid=event_uuid, @@ -1899,10 +1831,10 @@ def from_dict(cls, data: dict) -> 'AddApiIntegrationEvent': class AddWidgetIntegrationEvent(AddIntegrationEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - integration_id: str, name: str, props: list[str], logo: Optional[str], - item_url: Optional[str], widget_url: str): + integration_id: str, name: str, props: list[str], logo: str | None, + item_url: str | None, widget_url: str): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -1949,11 +1881,11 @@ def from_dict(cls, data: dict) -> 'AddWidgetIntegrationEvent': @event_class class EditIntegrationEvent(_KMEditEvent, abc.ABC): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], integration_id: EventField[str], name: EventField[str], - props: EventField[list[str]], logo: EventField[Optional[str]], - item_url: EventField[Optional[str]]): + props: EventField[list[str]], logo: EventField[str | None], + item_url: EventField[str | None]): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -1993,14 +1925,14 @@ def from_dict(cls, data: dict) -> 'EditIntegrationEvent': class EditApiIntegrationEvent(EditIntegrationEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], integration_id: EventField[str], name: EventField[str], - props: EventField[list[str]], logo: EventField[Optional[str]], - item_url: EventField[Optional[str]], rq_method: EventField[str], + props: EventField[list[str]], logo: EventField[str | None], + item_url: EventField[str | None], rq_method: EventField[str], rq_url: EventField[str], rq_headers: EventField[list[MapEntry]], rq_body: EventField[str], rq_empty_search: EventField[bool], - rs_list_field: EventField[Optional[str]], rs_item_id: EventField[Optional[str]], + rs_list_field: EventField[str | None], rs_item_id: EventField[str | None], rs_item_template: EventField[str]): super().__init__( event_uuid=event_uuid, @@ -2074,11 +2006,11 @@ def from_dict(cls, data: dict) -> 'EditApiIntegrationEvent': class EditWidgetIntegrationEvent(EditIntegrationEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], integration_id: EventField[str], name: EventField[str], - props: EventField[list[str]], logo: EventField[Optional[str]], - item_url: EventField[Optional[str]], widget_url: EventField[str]): + props: EventField[list[str]], logo: EventField[str | None], + item_url: EventField[str | None], widget_url: EventField[str]): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -2150,9 +2082,9 @@ def from_dict(cls, data: dict) -> 'DeleteIntegrationEvent': @event_class class AddMetricEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - title: str, abbreviation: Optional[str], description: Optional[str]): + title: str, abbreviation: str | None, description: str | None): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -2193,10 +2125,10 @@ def from_dict(cls, data: dict) -> 'AddMetricEvent': @event_class class EditMetricEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], abbreviation: EventField[Optional[str]], - description: EventField[Optional[str]]): + title: EventField[str], abbreviation: EventField[str | None], + description: EventField[str | None]): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -2262,9 +2194,9 @@ def from_dict(cls, data: dict) -> 'DeleteMetricEvent': @event_class class AddPhaseEvent(_KMAddEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: list[MapEntry], - title: str, description: Optional[str]): + title: str, description: str | None): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -2302,9 +2234,9 @@ def from_dict(cls, data: dict) -> 'AddPhaseEvent': @event_class class EditPhaseEvent(_KMEditEvent): - def __init__(self, event_uuid: str, entity_uuid: str, parent_uuid: str, + def __init__(self, *, event_uuid: str, entity_uuid: str, parent_uuid: str, created_at: str, annotations: EventField[list[MapEntry]], - title: EventField[str], description: EventField[Optional[str]]): + title: EventField[str], description: EventField[str | None]): super().__init__( event_uuid=event_uuid, entity_uuid=entity_uuid, @@ -2484,7 +2416,7 @@ class Event: @classmethod def from_dict(cls, data: dict): event_type = data['eventType'] - if event_type not in EVENT_TYPES.keys(): + if event_type not in EVENT_TYPES: raise ValueError(f'Unknown event type: {event_type}') t = EVENT_TYPES[event_type] if not hasattr(t, 'from_dict'): diff --git a/packages/dsw-models/dsw/models/km/package.py b/packages/dsw-models/dsw/models/km/package.py index edc4019e..a9a36205 100644 --- a/packages/dsw-models/dsw/models/km/package.py +++ b/packages/dsw-models/dsw/models/km/package.py @@ -1,27 +1,26 @@ +# pylint: disable=too-many-arguments, too-many-locals, too-many-lines from .events import _KMEvent, Event -from typing import Optional - class Package: - def __init__(self, km_id: str, org_id: str, version: str, name: str, - metamodel_version: int, description: str, license: str, - readme: str, created_at: str, fork_pkg_id: Optional[str], - merge_pkg_id: Optional[str], prev_pkg_id: Optional[str]): + def __init__(self, *, km_id: str, org_id: str, version: str, name: str, + metamodel_version: int, description: str, pkg_license: str, + readme: str, created_at: str, fork_pkg_id: str | None, + merge_pkg_id: str | None, prev_pkg_id: str | None): self.km_id = km_id self.org_id = org_id self.version = version self.name = name self.metamodel_version = metamodel_version self.description = description - self.license = license + self.license = pkg_license self.readme = readme self.created_at = created_at self.fork_pkg_id = fork_pkg_id self.merge_pkg_id = merge_pkg_id self.prev_pkg_id = prev_pkg_id - self.events = list() # type: list[_KMEvent] + self.events: list[_KMEvent] = [] @property def id(self): @@ -54,7 +53,7 @@ def from_dict(data: dict) -> 'Package': metamodel_version=data['metamodelVersion'], name=data['name'], description=data['description'], - license=data['license'], + pkg_license=data['license'], readme=data['readme'], created_at=data['createdAt'], fork_pkg_id=data['forkOfPackageId'], @@ -68,14 +67,14 @@ def from_dict(data: dict) -> 'Package': class PackageBundle: - def __init__(self, km_id: str, org_id: str, version: str, name: str, + def __init__(self, *, km_id: str, org_id: str, version: str, name: str, metamodel_version: int): self.km_id = km_id self.org_id = org_id self.version = version self.name = name self.metamodel_version = metamodel_version - self.packages = list() # type: list[Package] + self.packages: list[Package] = [] @property def id(self): diff --git a/packages/dsw-models/pyproject.toml b/packages/dsw-models/pyproject.toml index 0c7cb086..b93b4613 100644 --- a/packages/dsw-models/pyproject.toml +++ b/packages/dsw-models/pyproject.toml @@ -16,12 +16,12 @@ classifiers = [ 'Development Status :: 4 - Beta', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Text Processing', 'Topic :: Utilities', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ ] diff --git a/packages/dsw-storage/dsw/storage/s3storage.py b/packages/dsw-storage/dsw/storage/s3storage.py index ac0b8a9f..c32941cf 100644 --- a/packages/dsw-storage/dsw/storage/s3storage.py +++ b/packages/dsw-storage/dsw/storage/s3storage.py @@ -1,13 +1,12 @@ import contextlib import io import logging -import minio # type: ignore -import minio.error # type: ignore import pathlib import tempfile -import tenacity -from typing import Optional +import minio +import minio.error +import tenacity from dsw.config.model import S3Config @@ -35,7 +34,7 @@ def _get_endpoint(url: str): parts = url.split('://', maxsplit=1) return parts[0] if len(parts) == 1 else parts[1] - def __init__(self, cfg: S3Config, multi_tenant: bool): + def __init__(self, *, cfg: S3Config, multi_tenant: bool): self.cfg = cfg self.multi_tenant = multi_tenant self.client = minio.Minio( @@ -69,9 +68,9 @@ def ensure_bucket(self): before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def store_document(self, tenant_uuid: str, file_name: str, + def store_document(self, *, tenant_uuid: str, file_name: str, content_type: str, data: bytes, - metadata: Optional[dict] = None): + metadata: dict | None = None): object_name = f'{DOCUMENTS_DIR}/{file_name}' if self.multi_tenant: object_name = f'{tenant_uuid}/{object_name}' @@ -92,7 +91,7 @@ def store_document(self, tenant_uuid: str, file_name: str, before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def download_file(self, file_name: str, target_path: pathlib.Path) -> bool: + def download_file(self, *, file_name: str, target_path: pathlib.Path) -> bool: try: self.client.fget_object( bucket_name=self.cfg.bucket, @@ -112,9 +111,9 @@ def download_file(self, file_name: str, target_path: pathlib.Path) -> bool: before=tenacity.before_log(LOG, logging.DEBUG), after=tenacity.after_log(LOG, logging.DEBUG), ) - def store_object(self, tenant_uuid: str, object_name: str, + def store_object(self, *, tenant_uuid: str, object_name: str, content_type: str, data: bytes, - metadata: Optional[dict] = None): + metadata: dict | None = None): if self.multi_tenant: object_name = f'{tenant_uuid}/{object_name}' with io.BytesIO(data) as file: diff --git a/packages/dsw-storage/pyproject.toml b/packages/dsw-storage/pyproject.toml index f420ce47..aab7330a 100644 --- a/packages/dsw-storage/pyproject.toml +++ b/packages/dsw-storage/pyproject.toml @@ -16,12 +16,12 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Communications :: File Sharing', 'Topic :: Utilities', ] -requires-python = '>=3.10, <4' +requires-python = '>=3.11, <4' dependencies = [ 'minio', 'tenacity', diff --git a/packages/dsw-tdk/dsw/tdk/__main__.py b/packages/dsw-tdk/dsw/tdk/__main__.py index 4f261ad7..d7541e2a 100644 --- a/packages/dsw-tdk/dsw/tdk/__main__.py +++ b/packages/dsw-tdk/dsw/tdk/__main__.py @@ -1,3 +1,4 @@ from .cli import main +# pylint: disable-next=no-value-for-parameter main(ctx=None) diff --git a/packages/dsw-tdk/dsw/tdk/api_client.py b/packages/dsw-tdk/dsw/tdk/api_client.py index 11e8877f..d781ac48 100644 --- a/packages/dsw-tdk/dsw/tdk/api_client.py +++ b/packages/dsw-tdk/dsw/tdk/api_client.py @@ -1,10 +1,9 @@ -import aiohttp -import aiohttp.client_exceptions import functools import pathlib import urllib.parse -from typing import List, Optional +import aiohttp +import aiohttp.client_exceptions from .consts import DEFAULT_ENCODING, APP, VERSION from .model import Template, TemplateFile, TemplateFileType @@ -36,30 +35,31 @@ async def handled_client_call(job, *args, **kwargs): raise DSWCommunicationError( reason='Unexpected response type', message=e.message - ) + ) from e except aiohttp.client_exceptions.ClientResponseError as e: raise DSWCommunicationError( reason='Error response status', message=f'Server responded with error HTTP status {e.status}: {e.message}' - ) + ) from e except aiohttp.client_exceptions.InvalidURL as e: raise DSWCommunicationError( reason='Invalid URL', message=f'Provided API URL seems invalid: {e.url}' - ) + ) from e except aiohttp.client_exceptions.ClientConnectorError as e: raise DSWCommunicationError( reason='Server unreachable', message=f'Desired server is not reachable (errno {e.os_error.errno})' - ) + ) from e except Exception as e: raise DSWCommunicationError( reason='Communication error', message=f'Communication with server failed ({e})' - ) + ) from e return handled_client_call +# pylint: disable-next=too-many-public-methods class DSWAPIClient: def _headers(self, extra=None): @@ -91,7 +91,9 @@ def __init__(self, api_url: str, api_key: str, session=None): """ self.api_url = api_url self.token = api_key - self.session = session or aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) + self.session = session or aiohttp.ClientSession( + connector=aiohttp.TCPConnector(ssl=False), + ) @property def templates_endpoint(self): @@ -112,31 +114,48 @@ async def safe_close(self) -> bool: return False async def _post_json(self, endpoint, json) -> dict: - async with self.session.post(f'{self.api_url}{endpoint}', json=json, headers=self._headers()) as r: + async with self.session.post( + url=f'{self.api_url}{endpoint}', + json=json, + headers=self._headers(), + ) as r: self._check_status(r, expected_status=201) return await r.json() async def _put_json(self, endpoint, json) -> dict: - async with self.session.put(f'{self.api_url}{endpoint}', json=json, headers=self._headers()) as r: + async with self.session.put( + url=f'{self.api_url}{endpoint}', + json=json, + headers=self._headers(), + ) as r: self._check_status(r, expected_status=200) return await r.json() async def _get_json(self, endpoint) -> dict: - async with self.session.get(f'{self.api_url}{endpoint}', headers=self._headers()) as r: + async with self.session.get( + url=f'{self.api_url}{endpoint}', + headers=self._headers(), + ) as r: self._check_status(r, expected_status=200) return await r.json() async def _get_bytes(self, endpoint) -> bytes: - async with self.session.get(f'{self.api_url}{endpoint}', headers=self._headers()) as r: + async with self.session.get( + url=f'{self.api_url}{endpoint}', + headers=self._headers(), + ) as r: self._check_status(r, expected_status=200) return await r.read() async def _delete(self, endpoint) -> bool: - async with self.session.delete(f'{self.api_url}{endpoint}', headers=self._headers()) as r: + async with self.session.delete( + url=f'{self.api_url}{endpoint}', + headers=self._headers(), + ) as r: return r.status == 204 @handle_client_errors - async def login(self, email: str, password: str) -> Optional[str]: + async def login(self, email: str, password: str) -> str | None: req = {'email': email, 'password': password} body = await self._post_json('/tokens', json=req) self.token = body.get('token', None) @@ -148,7 +167,10 @@ async def get_current_user(self) -> dict: @handle_client_errors async def check_template_exists(self, remote_id: str) -> bool: - async with self.session.get(f'{self.templates_endpoint}/{remote_id}', headers=self._headers()) as r: + async with self.session.get( + url=f'{self.templates_endpoint}/{remote_id}', + headers=self._headers(), + ) as r: if r.status == 404: return False self._check_status(r, expected_status=200) @@ -156,20 +178,22 @@ async def check_template_exists(self, remote_id: str) -> bool: @handle_client_errors async def check_draft_exists(self, remote_id: str) -> bool: - async with self.session.get(f'{self.drafts_endpoint}/{remote_id}', headers=self._headers()) as r: + async with self.session.get( + url=f'{self.drafts_endpoint}/{remote_id}', + headers=self._headers(), + ) as r: if r.status == 404: return False self._check_status(r, expected_status=200) return True @handle_client_errors - async def get_templates(self) -> List[Template]: + async def get_templates(self) -> list[Template]: body = await self._get_json('/document-templates/all') return list(map(_load_remote_template, body)) @handle_client_errors - async def get_drafts(self) -> List[Template]: - # TODO: better paging + async def get_drafts(self) -> list[Template]: body = await self._get_json('/document-template-drafts?size=10000') drafts = body.get('_embedded', {}).get('documentTemplateDrafts', []) return list(map(_load_remote_template, drafts)) @@ -191,7 +215,7 @@ async def get_template_draft(self, remote_id: str) -> Template: return _load_remote_template(body) @handle_client_errors - async def get_template_draft_files(self, remote_id: str) -> List[TemplateFile]: + async def get_template_draft_files(self, remote_id: str) -> list[TemplateFile]: body = await self._get_json(f'/document-template-drafts/{remote_id}/files') result = [] for file_body in body: @@ -206,7 +230,7 @@ async def get_template_draft_file(self, remote_id: str, file_id: str) -> Templat return _load_remote_file(body) @handle_client_errors - async def get_template_draft_assets(self, remote_id: str) -> List[TemplateFile]: + async def get_template_draft_assets(self, remote_id: str) -> list[TemplateFile]: body = await self._get_json(f'/document-template-drafts/{remote_id}/assets') result = [] for file_body in body: @@ -217,8 +241,10 @@ async def get_template_draft_assets(self, remote_id: str) -> List[TemplateFile]: @handle_client_errors async def get_template_draft_asset(self, remote_id: str, asset_id: str) -> TemplateFile: - body = await self._get_json(f'/document-template-drafts/{remote_id}/assets/{asset_id}') - content = await self._get_bytes(f'/document-template-drafts/{remote_id}/assets/{asset_id}/content') + body = await self._get_json(f'/document-template-drafts/{remote_id}' + f'/assets/{asset_id}') + content = await self._get_bytes(f'/document-template-drafts/{remote_id}' + f'/assets/{asset_id}/content') return _load_remote_asset(body, content) @handle_client_errors @@ -258,7 +284,8 @@ async def post_template_draft_file(self, remote_id: str, tfile: TemplateFile): async def put_template_draft_file_content(self, remote_id: str, tfile: TemplateFile): self.session.headers.update(self._headers()) async with self.session.put( - f'{self.api_url}/document-template-drafts/{remote_id}/files/{tfile.remote_id}/content', + f'{self.api_url}/document-template-drafts/{remote_id}' + f'/files/{tfile.remote_id}/content', data=tfile.content, headers={'Content-Type': 'text/plain;charset=UTF-8'}, ) as r: @@ -302,7 +329,8 @@ async def put_template_draft_asset_content(self, remote_id: str, tfile: Template value=tfile.filename.as_posix(), ) async with self.session.put( - f'{self.api_url}/document-template-drafts/{remote_id}/assets/{tfile.remote_id}/content', + f'{self.api_url}/document-template-drafts/{remote_id}' + f'/assets/{tfile.remote_id}/content', data=data, headers=self._headers() ) as r: @@ -338,11 +366,11 @@ async def get_organization_id(self) -> str: def _load_remote_file(data: dict) -> TemplateFile: - content = data.get('content', None) # type: str - filename = str(data.get('fileName', '')) # type: str + content: str = data.get('content', '') + filename: str = str(data.get('fileName', '')) template_file = TemplateFile( remote_id=data.get('uuid', None), - remote_type=TemplateFileType.file, + remote_type=TemplateFileType.FILE, filename=pathlib.Path(urllib.parse.unquote(filename)), content=content.encode(encoding=DEFAULT_ENCODING), ) @@ -353,7 +381,7 @@ def _load_remote_asset(data: dict, content: bytes) -> TemplateFile: filename = str(data.get('fileName', '')) template_file = TemplateFile( remote_id=data.get('uuid', None), - remote_type=TemplateFileType.asset, + remote_type=TemplateFileType.ASSET, filename=pathlib.Path(urllib.parse.unquote(filename)), content_type=data.get('contentType', None), content=content, diff --git a/packages/dsw-tdk/dsw/tdk/cli.py b/packages/dsw-tdk/dsw/tdk/cli.py index b85b2475..1fbfc3ee 100644 --- a/packages/dsw-tdk/dsw/tdk/cli.py +++ b/packages/dsw-tdk/dsw/tdk/cli.py @@ -1,17 +1,17 @@ +# pylint: disable=too-many-positional-arguments import asyncio -import signal - -import click # type: ignore import datetime -import dotenv -import humanize # type: ignore import logging import mimetypes import pathlib -import slugify -import watchfiles # type: ignore +import signal +import sys -from typing import Dict +import click +import dotenv +import humanize +import slugify +import watchfiles from .api_client import DSWCommunicationError from .core import TDKCore, TDKProcessingError @@ -43,7 +43,7 @@ def error(message: str, **kwargs): @staticmethod def warning(message: str, **kwargs): - click.secho('WARNING', fg='yellow', bold=True, nl=False) + click.secho('WARNING', fg='yellow', bold=True, nl=False, **kwargs) click.echo(f': {message}') @staticmethod @@ -62,7 +62,8 @@ def watch(message: str): click.echo(f': {message}') @classmethod - def watch_change(cls, change_type: watchfiles.Change, filepath: pathlib.Path, root: pathlib.Path): + def watch_change(cls, change_type: watchfiles.Change, filepath: pathlib.Path, + root: pathlib.Path): timestamp = datetime.datetime.now().isoformat(timespec='milliseconds') sign = cls.CHANGE_SIGNS[change_type] click.secho('WATCH', fg='blue', bold=True, nl=False) @@ -93,10 +94,12 @@ def print_template_info(template: Template): click.echo(f' - {tfile.filename.as_posix()} [{filesize}]') +# pylint: disable-next=unused-argument def rectify_url(ctx, param, value) -> str: return value.rstrip('/') +# pylint: disable-next=unused-argument def rectify_key(ctx, param, value) -> str: return value.strip() @@ -130,7 +133,7 @@ def _format_level(self, level, justify=False): name = logging.getLevelName(level) # type: str if justify: name = name.ljust(8, ' ') - if self.colors and level in self.LEVEL_STYLES.keys(): + if self.colors and level in self.LEVEL_STYLES: name = self.LEVEL_STYLES[level](name) return name @@ -143,6 +146,7 @@ def _print_message(self, level, message): click.echo(self._format_level(level, justify=self.show_timestamp) + sep, nl=False) click.echo(message) + # pylint: disable-next=unused-argument def _log(self, level, msg, *args, **kwargs): if not self.muted: # super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel) @@ -165,9 +169,9 @@ def get_command(self, ctx, cmd_name): if x.startswith(cmd_name)] if not matches: return None - elif len(matches) == 1: + if len(matches) == 1: return click.Group.get_command(self, ctx, matches[0]) - ctx.fail('Too many matches: %s' % ', '.join(sorted(matches))) + return ctx.fail(f'Too many matches: {", ".join(sorted(matches))}') class CLIContext: @@ -184,21 +188,26 @@ def quiet_mode(self): self.logger.muted = True -def interact_formats() -> Dict[str, FormatSpec]: +def interact_formats() -> dict[str, FormatSpec]: add_format = click.confirm('Do you want to add a format?', default=True) - formats = dict() # type: Dict[str, FormatSpec] + formats: dict[str, FormatSpec] = {} while add_format: format_spec = FormatSpec() prompt_fill('Format name', obj=format_spec, attr='name', default='HTML') - if format_spec.name not in formats.keys() or click.confirm( + if format_spec.name not in formats or click.confirm( 'There is already a format with this name. Do you want to change it?' ): prompt_fill('File extension', obj=format_spec, attr='file_extension', default=format_spec.name.lower() if ' ' not in format_spec.name else None) prompt_fill('Content type', obj=format_spec, attr='content_type', default=mimetypes.types_map.get(f'.{format_spec.file_extension}', None)) - default_filename = str(pathlib.Path('src') / f'template.{format_spec.file_extension}.j2') - prompt_fill('Jinja2 filename', obj=format_spec, attr='filename', default=default_filename) + t_path = pathlib.Path('src') / f'template.{format_spec.file_extension}.j2' + prompt_fill( + text='Jinja2 filename', + obj=format_spec, + attr='filename', + default=str(t_path), + ) formats[format_spec.name] = format_spec click.echo('=' * 60) add_format = click.confirm('Do you want to add yet another format?', default=False) @@ -208,10 +217,14 @@ def interact_formats() -> Dict[str, FormatSpec]: def interact_builder(builder: TemplateBuilder): prompt_fill('Template name', obj=builder, attr='name') prompt_fill('Organization ID', obj=builder, attr='organization_id') - prompt_fill('Template ID', obj=builder, attr='template_id', default=slugify.slugify(builder.name)) - prompt_fill('Version', obj=builder, attr='version', default='0.1.0') - prompt_fill('Description', obj=builder, attr='description', default='My custom template') - prompt_fill('License', obj=builder, attr='license', default='CC0') + prompt_fill('Template ID', obj=builder, attr='template_id', + default=slugify.slugify(builder.name)) + prompt_fill('Version', obj=builder, attr='version', + default='0.1.0') + prompt_fill('Description', obj=builder, attr='description', + default='My custom template') + prompt_fill('License', obj=builder, attr='license', + default='CC0') click.echo('=' * 60) formats = interact_formats() for format_spec in formats.values(): @@ -224,19 +237,16 @@ def load_local(tdk: TDKCore, template_dir: pathlib.Path): except Exception as e: ClickPrinter.failure('Could not load local template') ClickPrinter.error(f'> {e}') - exit(1) + sys.exit(1) def dir_from_id(template_id: str) -> pathlib.Path: return pathlib.Path.cwd() / template_id.replace(':', '_') -############################################################################################################# - - @click.group(cls=AliasedGroup) -@click.option('-e', '--dot-env', default='.env', required=False, show_default=True, - type=click.Path(file_okay=True, dir_okay=False), +@click.option('-e', '--dot-env', default='.env', required=False, + show_default=True, type=click.Path(file_okay=True, dir_okay=False), help='Provide file with environment variables.') @click.option('-q', '--quiet', is_flag=True, help='Hide additional information logs.') @@ -266,7 +276,7 @@ def new_template(ctx, template_dir, force): except Exception: click.echo('') ClickPrinter.failure('Exited...') - exit(1) + sys.exit(1) tdk = TDKCore(template=builder.build(), logger=ctx.obj.logger) template_dir = template_dir or dir_from_id(tdk.safe_template.id) tdk.prepare_local(template_dir=template_dir) @@ -276,7 +286,7 @@ def new_template(ctx, template_dir, force): except Exception as e: ClickPrinter.failure('Could not create new template project') ClickPrinter.error(f'> {e}') - exit(1) + sys.exit(1) @main.command(help='Download template from Wizard.', name='get') @@ -309,30 +319,32 @@ async def main_routine(): ClickPrinter.error('Could not get template:', bold=True) ClickPrinter.error(f'> {e.reason}\n> {e.message}') await tdk.safe_client.close() - exit(1) + sys.exit(1) await tdk.safe_client.safe_close() if template_type == 'draft': tdk.prepare_local(template_dir=template_dir) try: tdk.store_local(force=force) - ClickPrinter.success(f'Template draft {template_id} downloaded to {template_dir}') + ClickPrinter.success(f'Template draft {template_id} ' + f'downloaded to {template_dir}') except Exception as e: ClickPrinter.failure('Could not store template locally') ClickPrinter.error(f'> {e}') await tdk.safe_client.close() - exit(1) + sys.exit(1) elif template_type == 'bundle' and zip_data is not None: try: tdk.extract_package(zip_data=zip_data, template_dir=template_dir, force=force) - ClickPrinter.success(f'Template {template_id} (released) downloaded to {template_dir}') + ClickPrinter.success(f'Template {template_id} (released) ' + f'downloaded to {template_dir}') except Exception as e: ClickPrinter.failure('Could not store template locally') ClickPrinter.error(f'> {e}') await tdk.safe_client.close() - exit(1) + sys.exit(1) else: ClickPrinter.failure(f'{template_id} is not released nor draft of a document template') - exit(1) + sys.exit(1) loop = asyncio.get_event_loop() loop.run_until_complete(main_routine()) @@ -345,8 +357,10 @@ async def main_routine(): @click.option('-k', '--api-key', metavar='API-KEY', envvar='DSW_API_KEY', prompt='API Key', help='API key for Wizard instance.', callback=rectify_key, hide_input=True) -@click.option('-f', '--force', is_flag=True, help='Delete template if already exists.') -@click.option('-w', '--watch', is_flag=True, help='Enter watch mode to continually upload changes.') +@click.option('-f', '--force', is_flag=True, + help='Delete template if already exists.') +@click.option('-w', '--watch', is_flag=True, + help='Enter watch mode to continually upload changes.') @click.pass_context def put_template(ctx, api_url, template_dir, api_key, force, watch): tdk = TDKCore(logger=ctx.obj.logger) @@ -380,7 +394,7 @@ async def main_routine(): ClickPrinter.failure('Could not upload template') ClickPrinter.error(f'> {e.message}\n> {e.hint}') await tdk.safe_client.safe_close() - exit(1) + sys.exit(1) except DSWCommunicationError as e: ClickPrinter.failure('Could not upload template') ClickPrinter.error(f'> {e.reason}\n> {e.message}') @@ -388,8 +402,9 @@ async def main_routine(): 'or template already exists...') ClickPrinter.error('> Check if you are using the matching version') await tdk.safe_client.safe_close() - exit(1) + sys.exit(1) + # pylint: disable-next=unused-argument def set_stop_event(signum, frame): signame = signal.Signals(signum).name ClickPrinter.warning(f'Got {signame}, finishing... Bye!') @@ -417,7 +432,7 @@ def create_package(ctx, template_dir, output, force: bool): except Exception as e: ClickPrinter.failure('Failed to create the package') ClickPrinter.error(f'> {e}') - exit(1) + sys.exit(1) filename = click.style(output, bold=True) ClickPrinter.success(f'Package {filename} created') @@ -440,7 +455,7 @@ def extract_package(ctx, template_package, output, force: bool): except Exception as e: ClickPrinter.failure('Failed to extract the package') ClickPrinter.error(f'> {e}') - exit(1) + sys.exit(1) ClickPrinter.success(f'Package {template_package} extracted') @@ -484,7 +499,7 @@ async def main_routine(): ClickPrinter.failure('Failed to get list of templates') ClickPrinter.error(f'> {e.reason}\n> {e.message}') await tdk.safe_client.safe_close() - exit(1) + sys.exit(1) loop = asyncio.get_event_loop() loop.run_until_complete(main_routine()) @@ -529,4 +544,4 @@ def create_dot_env(ctx, template_dir, api_url, api_key, force): except Exception as e: ClickPrinter.failure('Failed to create dot-env file') ClickPrinter.error(f'> {e}') - exit(1) + sys.exit(1) diff --git a/packages/dsw-tdk/dsw/tdk/consts.py b/packages/dsw-tdk/dsw/tdk/consts.py index 213e7069..50d9961d 100644 --- a/packages/dsw-tdk/dsw/tdk/consts.py +++ b/packages/dsw-tdk/dsw/tdk/consts.py @@ -1,7 +1,8 @@ import pathlib -import pathspec # type: ignore import re +import pathspec + APP = 'dsw-tdk' VERSION = '4.13.0' METAMODEL_VERSION = 16 @@ -17,4 +18,4 @@ DEFAULT_README = pathlib.Path('README.md') TEMPLATE_FILE = 'template.json' -PATHSPEC_FACTORY = pathspec.patterns.GitWildMatchPattern +PathspecFactory = pathspec.patterns.GitWildMatchPattern diff --git a/packages/dsw-tdk/dsw/tdk/core.py b/packages/dsw-tdk/dsw/tdk/core.py index 5f4f661a..8eb403a8 100644 --- a/packages/dsw-tdk/dsw/tdk/core.py +++ b/packages/dsw-tdk/dsw/tdk/core.py @@ -7,11 +7,9 @@ import re import shutil import tempfile -import watchfiles # type: ignore import zipfile - -from typing import List, Optional, Tuple +import watchfiles from .api_client import DSWAPIClient, DSWCommunicationError from .consts import DEFAULT_ENCODING, REGEX_SEMVER @@ -20,7 +18,7 @@ from .validation import ValidationError, TemplateValidator -ChangeItem = Tuple[watchfiles.Change, pathlib.Path] +ChangeItem = tuple[watchfiles.Change, pathlib.Path] def _change(item: ChangeItem, root: pathlib.Path) -> str: @@ -54,6 +52,7 @@ def __init__(self, message: str, hint: str): } +# pylint: disable=too-many-public-methods class TDKCore: def _check_metamodel_version(self): @@ -64,24 +63,24 @@ def _check_metamodel_version(self): if 'v' == api_version[0]: api_version = api_version[1:] if not re.match(REGEX_SEMVER, api_version): - self.logger.warning(f'Using non-stable release of API: {self.remote_version}') + self.logger.warning('Using non-stable release of API: %s', self.remote_version) return parts = api_version.split('.') ver = (int(parts[0]), int(parts[1]), int(parts[2])) vtag = f'v{ver[0]}.{ver[1]}.{ver[2]}' hint = 'Fix your metamodelVersion in template.json and/or visit docs' - if mm_ver not in METAMODEL_VERSION_SUPPORT.keys(): + if mm_ver not in METAMODEL_VERSION_SUPPORT: raise TDKProcessingError(f'Unknown metamodel version: {mm_ver}', hint) min_version = METAMODEL_VERSION_SUPPORT[mm_ver] if min_version > ver: raise TDKProcessingError(f'Unsupported metamodel version for API {vtag}', hint) - if mm_ver + 1 in METAMODEL_VERSION_SUPPORT.keys(): + if mm_ver + 1 in METAMODEL_VERSION_SUPPORT: max_version = METAMODEL_VERSION_SUPPORT[mm_ver + 1] if ver >= max_version: raise TDKProcessingError(f'Unsupported metamodel version for API {vtag}', hint) - def __init__(self, template: Optional[Template] = None, project: Optional[TemplateProject] = None, - client: Optional[DSWAPIClient] = None, logger: Optional[logging.Logger] = None): + def __init__(self, template: Template | None = None, project: TemplateProject | None = None, + client: DSWAPIClient | None = None, logger: logging.Logger | None = None): self.template = template self.project = project self.client = client @@ -113,13 +112,13 @@ def safe_client(self) -> DSWAPIClient: return self.client async def init_client(self, api_url: str, api_key: str): - self.logger.info(f'Connecting to {api_url}') + self.logger.info('Connecting to %s', api_url) self.client = DSWAPIClient(api_url=api_url, api_key=api_key) self.remote_version = await self.client.get_api_version() user = await self.client.get_current_user() - self.logger.info(f'Successfully authenticated as {user["firstName"]} ' - f'{user["lastName"]} ({user["email"]})') - self.logger.debug(f'Connected to API version {self.remote_version}') + self.logger.info('Successfully authenticated as %s %s (%s)', + user['firstName'], user['lastName'], user['email']) + self.logger.debug('Connected to API version %s', self.remote_version) def prepare_local(self, template_dir): self.logger.debug('Preparing local template project') @@ -131,32 +130,32 @@ def load_local(self, template_dir): self.safe_project.load() async def load_remote(self, template_id: str): - self.logger.info(f'Retrieving template draft {template_id}') + self.logger.info('Retrieving template draft %s', template_id) self.template = await self.safe_client.get_template_draft(remote_id=template_id) self.logger.debug('Retrieving template draft files') files = await self.safe_client.get_template_draft_files(remote_id=template_id) - self.logger.info(f'Retrieved {len(files)} file(s)') + self.logger.info('Retrieved %s file(s)', len(files)) for tfile in files: self.safe_template.files[tfile.filename.as_posix()] = tfile self.logger.debug('Retrieving template draft assets') assets = await self.safe_client.get_template_draft_assets(remote_id=template_id) - self.logger.info(f'Retrieved {len(assets)} asset(s)') + self.logger.info('Retrieved %s asset(s)', len(assets)) for tfile in assets: self.safe_template.files[tfile.filename.as_posix()] = tfile async def download_bundle(self, template_id: str) -> bytes: - self.logger.info(f'Retrieving template {template_id} bundle') + self.logger.info('Retrieving template %s bundle', template_id) return await self.safe_client.get_template_bundle(remote_id=template_id) - async def list_remote_templates(self) -> List[Template]: + async def list_remote_templates(self) -> list[Template]: self.logger.info('Listing remote document templates') return await self.safe_client.get_templates() - async def list_remote_drafts(self) -> List[Template]: + async def list_remote_drafts(self) -> list[Template]: self.logger.info('Listing remote document template drafts') return await self.safe_client.get_drafts() - def verify(self) -> List[ValidationError]: + def verify(self) -> list[ValidationError]: template = self.template or self.safe_project.template if template is None: raise RuntimeError('No template is loaded') @@ -166,7 +165,7 @@ def store_local(self, force: bool): if self.project is None: raise RuntimeError('No template project is initialized') self.project.template = self.safe_template - self.logger.debug(f'Initiating storing local template project (force={force})') + self.logger.debug('Initiating storing local template project (force=%s)', force) self.project.store(force=force) async def store_remote(self, force: bool): @@ -174,8 +173,9 @@ async def store_remote(self, force: bool): self._check_metamodel_version() org_id = await self.safe_client.get_organization_id() if org_id != self.safe_template.organization_id: - self.logger.warning(f'There is different organization ID set in the DSW instance' - f' (local: {self.safe_template.organization_id}, remote: {org_id})') + self.logger.warning('There is different organization ID set in the DSW instance' + ' (local: %s, remote: %s)', + self.safe_template.organization_id, org_id) self.remote_id = self.safe_template.id_with_org(org_id) template_exists = await self.safe_client.check_draft_exists(remote_id=self.remote_id) if template_exists and force: @@ -186,26 +186,39 @@ async def store_remote(self, force: bool): template_exists = not result if template_exists: - # TODO: do not remove if not necessary (make diff?) self.logger.info('Updating existing remote document template draft') - await self.safe_client.update_template_draft(template=self.safe_template, remote_id=self.remote_id) + await self.safe_client.update_template_draft( + template=self.safe_template, + remote_id=self.remote_id, + ) self.logger.debug('Retrieving remote assets') - remote_assets = await self.safe_client.get_template_draft_assets(remote_id=self.remote_id) + remote_assets = await self.safe_client.get_template_draft_assets( + remote_id=self.remote_id, + ) self.logger.debug('Retrieving remote files') - remote_files = await self.safe_client.get_template_draft_files(remote_id=self.remote_id) - await self.cleanup_remote_files(remote_assets=remote_assets, remote_files=remote_files) + remote_files = await self.safe_client.get_template_draft_files( + remote_id=self.remote_id, + ) + await self.cleanup_remote_files( + remote_assets=remote_assets, + remote_files=remote_files, + ) else: self.logger.info('Creating remote document template draft') - await self.safe_client.create_new_template_draft(template=self.safe_template, remote_id=self.remote_id) + await self.safe_client.create_new_template_draft( + template=self.safe_template, + remote_id=self.remote_id, + ) await self.store_remote_files() async def _update_template_file(self, remote_tfile: TemplateFile, local_tfile: TemplateFile, project_update: bool = False): try: - self.logger.debug(f'Updating existing remote {remote_tfile.remote_type.value} ' - f'{remote_tfile.filename.as_posix()} ({remote_tfile.remote_id}) started') + self.logger.debug('Updating existing remote %s %s (%s) started', + remote_tfile.remote_type.value, remote_tfile.filename.as_posix(), + remote_tfile.remote_id) local_tfile.remote_id = remote_tfile.remote_id - if remote_tfile.remote_type == TemplateFileType.asset: + if remote_tfile.remote_type == TemplateFileType.ASSET: result = await self.safe_client.put_template_draft_asset_content( remote_id=self.remote_id, tfile=local_tfile, @@ -215,25 +228,27 @@ async def _update_template_file(self, remote_tfile: TemplateFile, local_tfile: T remote_id=self.remote_id, tfile=local_tfile, ) - self.logger.debug(f'Updating existing remote {remote_tfile.remote_type.value} ' - f'{remote_tfile.filename.as_posix()} ({remote_tfile.remote_id}) ' - f'finished: {"ok" if result else "failed"}') + self.logger.debug('Updating existing remote %s %s (%s) finished: %s', + remote_tfile.remote_type.value, remote_tfile.filename.as_posix(), + remote_tfile.remote_id, 'ok' if result else 'failed') if project_update and result: self.safe_project.update_template_file(result) - except Exception as e: + except Exception as e1: try: - self.logger.debug(f'Trying to delete/create due to: {str(e)}') + self.logger.debug('Trying to delete/create due to: %s', str(e1)) await self._delete_template_file(tfile=remote_tfile) await self._create_template_file(tfile=local_tfile, project_update=True) - except Exception as e: - self.logger.error(f'Failed to update existing remote {remote_tfile.remote_type.value} ' - f'{remote_tfile.filename.as_posix()}: {e}') + except Exception as e2: + self.logger.error('Failed to update existing remote %s %s: %s', + remote_tfile.remote_type.value, + remote_tfile.filename.as_posix(), e2) async def _delete_template_file(self, tfile: TemplateFile, project_update: bool = False): try: - self.logger.debug(f'Deleting existing remote {tfile.remote_type.value} ' - f'{tfile.filename.as_posix()} ({tfile.remote_id}) started') - if tfile.remote_type == TemplateFileType.asset: + self.logger.debug('Deleting existing remote %s %s (%s) started', + tfile.remote_type.value, tfile.filename.as_posix(), + tfile.remote_id) + if tfile.remote_type == TemplateFileType.ASSET: result = await self.safe_client.delete_template_draft_asset( remote_id=self.remote_id, asset_id=tfile.remote_id, @@ -243,18 +258,19 @@ async def _delete_template_file(self, tfile: TemplateFile, project_update: bool remote_id=self.remote_id, file_id=tfile.remote_id, ) - self.logger.debug(f'Deleting existing remote {tfile.remote_type.value} ' - f'{tfile.filename.as_posix()} ({tfile.remote_id}) ' - f'finished: {"ok" if result else "failed"}') + self.logger.debug('Deleting existing remote %s %s (%s) finished: %s', + tfile.remote_type.value, tfile.filename.as_posix(), + tfile.remote_id, 'ok' if result else 'failed') if project_update and result: self.safe_project.remove_template_file(tfile.filename) except Exception as e: - self.logger.error(f'Failed to delete existing remote {tfile.remote_type.value} ' - f'{tfile.filename.as_posix()}: {e}') + self.logger.error('Failed to delete existing remote %s %s: %s', + tfile.remote_type.value, tfile.filename.as_posix(), e) - async def cleanup_remote_files(self, remote_assets: List[TemplateFile], remote_files: List[TemplateFile]): + async def cleanup_remote_files(self, remote_assets: list[TemplateFile], + remote_files: list[TemplateFile]): for tfile in self.safe_project.safe_template.files.values(): - self.logger.debug(f'Cleaning up remote {tfile.filename.as_posix()}') + self.logger.debug('Cleaning up remote %s', tfile.filename.as_posix()) for remote_asset in remote_assets: if remote_asset.filename == tfile.filename: await self._delete_template_file(tfile=remote_asset, project_update=False) @@ -264,60 +280,68 @@ async def cleanup_remote_files(self, remote_assets: List[TemplateFile], remote_f async def _create_template_file(self, tfile: TemplateFile, project_update: bool = False): try: - self.logger.debug(f'Storing remote {tfile.remote_type.value} ' - f'{tfile.filename.as_posix()} started') - if tfile.remote_type == TemplateFileType.asset: - result = await self.safe_client.post_template_draft_asset(remote_id=self.remote_id, tfile=tfile) + self.logger.debug('Storing remote %s %s started', + tfile.remote_type.value, tfile.filename.as_posix()) + if tfile.remote_type == TemplateFileType.ASSET: + result = await self.safe_client.post_template_draft_asset( + remote_id=self.remote_id, + tfile=tfile, + ) else: - result = await self.safe_client.post_template_draft_file(remote_id=self.remote_id, tfile=tfile) - self.logger.debug(f'Storing remote {tfile.remote_type.value} ' - f'{tfile.filename.as_posix()} finished: {result.remote_id}') + result = await self.safe_client.post_template_draft_file( + remote_id=self.remote_id, + tfile=tfile, + ) + self.logger.debug('Storing remote %s %s finished: %s', + tfile.remote_type.value, tfile.filename.as_posix(), result.remote_id) if project_update and result is not None: self.safe_project.update_template_file(result) except Exception as e: - self.logger.error(f'Failed to store remote {tfile.remote_type.value} {tfile.filename.as_posix()}: {str(e)}') + self.logger.error('Failed to store remote %s %s: %s', + tfile.remote_type.value, tfile.filename.as_posix(), e) async def store_remote_files(self): for tfile in self.safe_project.safe_template.files.values(): tfile.remote_id = None - tfile.remote_type = TemplateFileType.file if tfile.is_text else TemplateFileType.asset + tfile.remote_type = TemplateFileType.FILE if tfile.is_text else TemplateFileType.ASSET await self._create_template_file(tfile=tfile, project_update=True) def create_package(self, output: pathlib.Path, force: bool): if output.exists() and not force: raise RuntimeError(f'File {output} already exists (not forced)') - self.logger.debug(f'Opening ZIP file for write: {output}') + self.logger.debug('Opening ZIP file for write: %s', output.as_posix()) with zipfile.ZipFile(output, mode='w', compression=zipfile.ZIP_DEFLATED) as pkg: descriptor = self.safe_project.safe_template.serialize_remote() files = [] assets = [] for tfile in self.safe_project.safe_template.files.values(): if tfile.is_text: - self.logger.info(f'Adding template file {tfile.filename.as_posix()}') + self.logger.info('Adding template file %s', tfile.filename.as_posix()) files.append({ 'uuid': str(UUIDGen.generate()), 'content': tfile.content.decode(encoding=DEFAULT_ENCODING), 'fileName': str(tfile.filename.as_posix()), }) else: - self.logger.info(f'Adding template asset {tfile.filename.as_posix()}') + self.logger.info('Adding template asset %s', tfile.filename.as_posix()) assets.append({ 'uuid': str(UUIDGen.generate()), 'contentType': tfile.content_type, 'fileName': str(tfile.filename.as_posix()), }) - self.logger.debug(f'Packaging template asset {tfile.filename}') + self.logger.debug('Packaging template asset %s', tfile.filename.as_posix()) pkg.writestr(f'template/assets/{tfile.filename.as_posix()}', tfile.content) descriptor['files'] = files descriptor['assets'] = assets - timestamp = datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.datetime.now(tz=datetime.UTC).strftime('%Y-%m-%dT%H:%M:%S.%fZ') descriptor['createdAt'] = timestamp descriptor['updatedAt'] = timestamp self.logger.debug('Packaging template.json file') pkg.writestr('template/template.json', data=json.dumps(descriptor, indent=4)) self.logger.debug('ZIP packaging done') - def extract_package(self, zip_data: bytes, template_dir: Optional[pathlib.Path], force: bool): + # pylint: disable=too-many-locals + def extract_package(self, zip_data: bytes, template_dir: pathlib.Path | None, force: bool): with tempfile.TemporaryDirectory() as tmp_dir: io_zip = io.BytesIO(zip_data) with zipfile.ZipFile(io_zip) as pkg: @@ -373,7 +397,7 @@ def extract_package(self, zip_data: bytes, template_dir: Optional[pathlib.Path], def create_dot_env(self, output: pathlib.Path, force: bool, api_url: str, api_key: str): if output.exists(): if force: - self.logger.warning(f'Overwriting {output.as_posix()} (forced)') + self.logger.warning('Overwriting %s (forced)', output.as_posix()) else: raise RuntimeError(f'File {output} already exists (not forced)') output.write_text( @@ -393,47 +417,49 @@ async def watch_project(self, callback, stop_event: asyncio.Event): ) )) - async def _update_descriptor(self): + async def update_descriptor(self): try: template_exists = await self.safe_client.check_draft_exists( remote_id=self.remote_id, ) if template_exists: - self.logger.info(f'Updating existing remote document template draft' - f' {self.safe_project.safe_template.id}') + self.logger.info('Updating existing remote document template draft %s', + self.safe_project.safe_template.id) await self.safe_client.update_template_draft( template=self.safe_project.safe_template, remote_id=self.remote_id, ) else: - # TODO: optimization - reload full template and send it, skip all other changes - self.logger.info(f'Document template draft {self.safe_project.safe_template.id} ' - f'does not exist on remote - full sync') + self.logger.info('Document template draft %s does not exist on remote - full sync', + self.safe_project.safe_template.id) await self.store_remote(force=False) except DSWCommunicationError as e: - self.logger.error(f'Failed to update document template draft' - f' {self.safe_project.safe_template.id}: {e.message}') + self.logger.error('Failed to update document template draft %s: %s', + self.safe_project.safe_template.id, e.message) except Exception as e: - self.logger.error(f'Failed to update document template draft' - f' {self.safe_project.safe_template.id}: {e}') + self.logger.error('Failed to update document template draft %s: %s', + self.safe_project.safe_template.id, e) - async def _delete_file(self, filepath: pathlib.Path): + async def delete_file(self, filepath: pathlib.Path): if not filepath.is_file(): - self.logger.debug(f'{filepath.as_posix()} is not a regular file - skipping') + self.logger.debug('%s is not a regular file - skipping', + filepath.as_posix()) return try: tfile = self.safe_project.get_template_file(filepath=filepath) if tfile is None: - # TODO: try to check if exists on remote (may not be synced yet) - self.logger.info(f'File {filepath.as_posix()} not tracked currently - skipping') + self.logger.info('File %s not tracked currently - skipping', + filepath.as_posix()) return await self._delete_template_file(tfile=tfile, project_update=True) except Exception as e: - self.logger.error(f'Failed to delete file {filepath.as_posix()}: {e}') + self.logger.error('Failed to delete file %s: %s', + filepath.as_posix(), e) - async def _update_file(self, filepath: pathlib.Path): + async def update_file(self, filepath: pathlib.Path): if not filepath.is_file(): - self.logger.debug(f'{filepath.as_posix()} is not a regular file - skipping') + self.logger.debug('%s is not a regular file - skipping', + filepath.as_posix()) return try: remote_tfile = self.safe_project.get_template_file(filepath=filepath) @@ -443,30 +469,30 @@ async def _update_file(self, filepath: pathlib.Path): else: await self._create_template_file(tfile=local_tfile, project_update=True) except Exception as e: - self.logger.error(f'Failed to update file {filepath.as_posix()}: {e}') + self.logger.error('Failed to update file %s: %s', filepath.as_posix(), e) - async def process_changes(self, changes: List[ChangeItem], force: bool): + async def process_changes(self, changes: list[ChangeItem], force: bool): self.changes_processor.clear() try: await self.changes_processor.process_changes(changes, force) except Exception as e: - self.logger.error(f'Failed to process changes: {e}') + self.logger.error('Failed to process changes: %s', e) class ChangesProcessor: def __init__(self, tdk: TDKCore): - self.tdk = tdk # type: TDKCore - self.descriptor_change = None # type: Optional[ChangeItem] - self.readme_change = None # type: Optional[ChangeItem] - self.file_changes = [] # type: List[ChangeItem] + self.tdk: TDKCore = tdk + self.descriptor_change: ChangeItem | None = None + self.readme_change: ChangeItem | None = None + self.file_changes: list[ChangeItem] = [] def clear(self): self.descriptor_change = None self.readme_change = None self.file_changes = [] - def _split_changes(self, changes: List[ChangeItem]): + def _split_changes(self, changes: list[ChangeItem]): for change in changes: if change[1] == self.tdk.safe_project.descriptor_path: self.descriptor_change = change @@ -479,30 +505,32 @@ async def _process_file_changes(self): deleted = set() updated = set() for file_change in self.file_changes: - self.tdk.logger.debug(f'Processing: {_change(file_change, self.tdk.safe_project.template_dir)}') + self.tdk.logger.debug('Processing: %s', + _change(file_change, self.tdk.safe_project.template_dir)) change_type = file_change[0] filepath = file_change[1] if change_type == watchfiles.Change.deleted and filepath not in deleted: self.tdk.logger.debug('Scheduling delete operation') deleted.add(filepath) - await self.tdk._delete_file(filepath) + await self.tdk.delete_file(filepath) elif filepath not in updated: self.tdk.logger.debug('Scheduling update operation') updated.add(filepath) - await self.tdk._update_file(filepath) + await self.tdk.update_file(filepath) async def _reload_descriptor(self, force: bool) -> bool: if self.descriptor_change is None: return False if self.descriptor_change[0] == watchfiles.Change.deleted: - raise RuntimeError(f'Deleted template descriptor {self.tdk.safe_project.descriptor_path} ... the end') - self.tdk.logger.debug(f'Reloading {TemplateProject.TEMPLATE_FILE} file') + raise RuntimeError(f'Deleted {self.tdk.safe_project.descriptor_path} ... the end') + self.tdk.logger.debug('Reloading %s file', TemplateProject.TEMPLATE_FILE) previous_id = self.tdk.safe_project.safe_template.id self.tdk.safe_project.load_descriptor() self.tdk.safe_project.load_readme() new_id = self.tdk.safe_project.safe_template.id if new_id != previous_id: - self.tdk.logger.warning(f'Template ID changed from {previous_id} to {new_id}') + self.tdk.logger.warning('Template ID changed from %s to %s', + previous_id, new_id) self.tdk.safe_project.load() await self.tdk.store_remote(force=force) self.tdk.logger.info('Template fully reloaded... waiting for new changes') @@ -521,10 +549,10 @@ async def _reload_readme(self) -> bool: async def _update_descriptor(self): if self.readme_change is not None or self.descriptor_change is not None: self.tdk.logger.debug('Updating template descriptor (metadata)') - await self.tdk._update_descriptor() + await self.tdk.update_descriptor() self.tdk.safe_project.template = self.tdk.safe_template - async def process_changes(self, changes: List[ChangeItem], force: bool): + async def process_changes(self, changes: list[ChangeItem], force: bool): self._split_changes(changes) full_reload = await self._reload_descriptor(force) if not full_reload: diff --git a/packages/dsw-tdk/dsw/tdk/model.py b/packages/dsw-tdk/dsw/tdk/model.py index 0046c80b..7bc9bdbd 100644 --- a/packages/dsw-tdk/dsw/tdk/model.py +++ b/packages/dsw-tdk/dsw/tdk/model.py @@ -3,29 +3,30 @@ import logging import mimetypes import pathlib -import pathspec # type: ignore from collections import OrderedDict -from typing import List, Dict, Optional, Tuple, Any +from typing import Any -from .consts import VERSION, DEFAULT_ENCODING, METAMODEL_VERSION, PATHSPEC_FACTORY +import pathspec + +from .consts import VERSION, DEFAULT_ENCODING, METAMODEL_VERSION, PathspecFactory mimetypes.init() class TemplateFileType(enum.Enum): - asset = 'asset' - file = 'file' + ASSET = 'asset' + FILE = 'file' class PackageFilter: - def __init__(self, *, organization_id=None, km_id=None, min_version=None, - max_version=None): - self.organization_id = organization_id # type: Optional[str] - self.km_id = km_id # type: Optional[str] - self.min_version = min_version # type: Optional[str] - self.max_version = max_version # type: Optional[str] + def __init__(self, *, organization_id: str | None = None, km_id: str | None = None, + min_version: str | None = None, max_version: str | None = None): + self.organization_id = organization_id + self.km_id = km_id + self.min_version = min_version + self.max_version = max_version @classmethod def load(cls, data): @@ -47,9 +48,10 @@ def serialize(self): class Step: - def __init__(self, *, name=None, options=None): - self.name = name # type: str - self.options = options or dict() # type: Dict[str, str] + def __init__(self, *, name: str | None = None, + options: dict[str, str] | None = None): + self.name = name + self.options = options or {} @classmethod def load(cls, data): @@ -69,11 +71,12 @@ class Format: DEFAULT_ICON = 'fas fa-file' - def __init__(self, *, uuid=None, name=None, icon=None): - self.uuid = uuid # type: str - self.name = name # type: str - self.icon = icon or self.DEFAULT_ICON # type: str - self.steps = [] # type: List[Step] + def __init__(self, *, uuid: str | None = None, name: str | None = None, + icon: str | None = None): + self.uuid = uuid + self.name = name + self.icon: str = icon or self.DEFAULT_ICON + self.steps: list[Step] = [] @classmethod def load(cls, data): @@ -100,11 +103,12 @@ class TDKConfig: DEFAULT_README = 'README.md' DEFAULT_FILES = ['*'] - def __init__(self, *, version=None, readme_file=None, files=None): - self.version = version or VERSION # type: str - readme_file_str = readme_file or self.DEFAULT_README # type: str - self.readme_file = pathlib.Path(readme_file_str) # type: pathlib.Path - self.files = files or [] # type: List[str] + def __init__(self, *, version: str | None = None, readme_file: str | None = None, + files: list[str] | None = None): + self.version: str = version or VERSION + readme_file_str: str = readme_file or self.DEFAULT_README + self.readme_file: pathlib.Path = pathlib.Path(readme_file_str) + self.files: list[str] = files or [] @classmethod def load(cls, data): @@ -127,19 +131,19 @@ class TemplateFile: DEFAULT_CONTENT_TYPE = 'application/octet-stream' TEMPLATE_EXTENSIONS = ('.j2', '.jinja', '.jinja2', '.jnj') - def __init__(self, *, remote_id=None, remote_type=None, filename=None, - content_type=None, content=None): - self.remote_id = remote_id # type: Optional[str] - self.filename = filename # type: pathlib.Path - self.content = content # type: bytes - self.content_type = content_type or self.guess_type() # type: str - self.remote_type = remote_type or self.guess_tfile_type() # type: TemplateFileType + def __init__(self, *, filename: pathlib.Path, + remote_id: str | None = None, remote_type: TemplateFileType | None = None, + content_type: str | None = None, content: bytes = b''): + self.remote_id = remote_id + self.filename = filename + self.content = content + self.content_type: str = content_type or self.guess_type() + self.remote_type: TemplateFileType = remote_type or self.guess_tfile_type() def guess_tfile_type(self): - return TemplateFileType.file if self.is_text else TemplateFileType.asset + return TemplateFileType.FILE if self.is_text else TemplateFileType.ASSET def guess_type(self) -> str: - # TODO: add own map of file extensions filename = self.filename.name for ext in self.TEMPLATE_EXTENSIONS: if filename.endswith(ext): @@ -151,8 +155,7 @@ def guess_type(self) -> str: @property def is_text(self): - # TODO: custom mapping (also some starting with "application" are textual) - if getattr(self, 'remote_type', None) == TemplateFileType.file: + if getattr(self, 'remote_type', None) == TemplateFileType.FILE: return True return self.content_type.startswith('text') @@ -163,6 +166,7 @@ def has_remote_id(self): class Template: + # pylint: disable-next=too-many-arguments def __init__(self, *, template_id=None, organization_id=None, version=None, name=None, description=None, readme=None, template_license=None, metamodel_version=None, tdk_config=None, loaded_json=None): @@ -173,13 +177,13 @@ def __init__(self, *, template_id=None, organization_id=None, version=None, name self.description = description # type: str self.readme = readme # type: str self.license = template_license # type: str - self.metamodel_version = metamodel_version or METAMODEL_VERSION # type: int - self.allowed_packages = [] # type: List[PackageFilter] - self.formats = [] # type: List[Format] - self.files = {} # type: Dict[str, TemplateFile] - self.extras = [] # type: List[TemplateFile] - self.tdk_config = tdk_config or TDKConfig() # type: TDKConfig - self.loaded_json = loaded_json or OrderedDict() # type: OrderedDict + self.metamodel_version: int = metamodel_version or METAMODEL_VERSION + self.allowed_packages: list[PackageFilter] = [] + self.formats: list[Format] = [] + self.files: dict[str, TemplateFile] = {} + self.extras: list[TemplateFile] = [] + self.tdk_config: TDKConfig = tdk_config or TDKConfig() + self.loaded_json: OrderedDict = loaded_json or OrderedDict() @property def id(self) -> str: @@ -200,8 +204,8 @@ def _common_load(cls, data): org_id = data['organizationId'] tmp_id = data['templateId'] version = data['version'] - except KeyError: - raise RuntimeError('Cannot retrieve template ID') + except KeyError as e: + raise RuntimeError('Cannot retrieve template ID') from e template = Template( template_id=tmp_id, organization_id=org_id, @@ -244,7 +248,7 @@ def serialize_local(self) -> OrderedDict: self.loaded_json['_tdk'] = self.tdk_config.serialize() return self.loaded_json - def serialize_remote(self) -> Dict[str, Any]: + def serialize_remote(self) -> dict[str, Any]: return { 'id': self.id, 'templateId': self.template_id, @@ -260,7 +264,7 @@ def serialize_remote(self) -> Dict[str, Any]: 'phase': 'DraftDocumentTemplatePhase', } - def serialize_for_update(self) -> Dict[str, Any]: + def serialize_for_update(self) -> dict[str, Any]: return { 'templateId': self.template_id, 'version': self.version, @@ -274,7 +278,7 @@ def serialize_for_update(self) -> Dict[str, Any]: 'phase': 'DraftDocumentTemplatePhase', } - def serialize_for_create(self, based_on: Optional[str] = None) -> Dict[str, Any]: + def serialize_for_create(self, based_on: str | None = None) -> dict[str, Any]: return { 'basedOn': based_on, 'name': self.name, @@ -282,7 +286,7 @@ def serialize_for_create(self, based_on: Optional[str] = None) -> Dict[str, Any] 'version': self.version, } - def serialize_local_new(self) -> Dict[str, Any]: + def serialize_local_new(self) -> dict[str, Any]: return { 'templateId': self.template_id, 'organizationId': self.organization_id, @@ -297,7 +301,7 @@ def serialize_local_new(self) -> Dict[str, Any]: } -def _to_ordered_dict(tuples: List[Tuple[str, Any]]) -> OrderedDict: +def _to_ordered_dict(tuples: list[tuple[str, Any]]) -> OrderedDict: return OrderedDict(tuples) @@ -309,12 +313,12 @@ class TemplateProject: json_decoder = json.JSONDecoder(object_pairs_hook=_to_ordered_dict) - def __init__(self, template_dir, logger): - self.template_dir = pathlib.Path(template_dir) # type: pathlib.Path + def __init__(self, template_dir: pathlib.Path, logger: logging.Logger): + self.template_dir = pathlib.Path(template_dir) self.descriptor_path = self.template_dir / self.TEMPLATE_FILE - self.template = None # type: Optional[Template] - self.used_readme = None # type: Optional[pathlib.Path] - self._logger = logger # type: logging.Logger + self.template: Template | None = None + self.used_readme: pathlib.Path | None = None + self._logger = logger @property def logger(self) -> logging.Logger: @@ -332,8 +336,8 @@ def load_descriptor(self): try: content = self.descriptor_path.read_text(encoding=DEFAULT_ENCODING) self.template = Template.load_local(self.json_decoder.decode(content)) - except Exception: - raise RuntimeError(f'Unable to load template using {self.descriptor_path}.') + except Exception as e: + raise RuntimeError(f'Unable to load template using {self.descriptor_path}.') from e def load_readme(self): readme = self.safe_template.tdk_config.readme_file @@ -342,7 +346,7 @@ def load_readme(self): self.used_readme = self.template_dir / readme self.safe_template.readme = self.used_readme.read_text(encoding=DEFAULT_ENCODING) except Exception as e: - raise RuntimeWarning(f'README file "{readme}" cannot be loaded: {e}') + raise RuntimeWarning(f'README file "{readme}" cannot be loaded: {e}') from e def load_file(self, filepath: pathlib.Path) -> TemplateFile: try: @@ -354,7 +358,7 @@ def load_file(self, filepath: pathlib.Path) -> TemplateFile: self.safe_template.files[filepath.as_posix()] = tfile return tfile except Exception as e: - raise RuntimeWarning(f'Failed to load template file {filepath}: {e}') + raise RuntimeWarning(f'Failed to load template file {filepath}: {e}') from e def load_files(self): self.safe_template.files.clear() @@ -363,22 +367,24 @@ def load_files(self): @property def files_pathspec(self) -> pathspec.PathSpec: - # TODO: make this more efficient (reload only when tdk_config changes, otherwise cache) patterns = self.safe_template.tdk_config.files + self.DEFAULT_PATTERNS - return pathspec.PathSpec.from_lines(PATHSPEC_FACTORY, patterns) + return pathspec.PathSpec.from_lines(PathspecFactory, patterns) - def list_files(self) -> List[pathlib.Path]: - files = (pathlib.Path(p) for p in self.files_pathspec.match_tree_files(str(self.template_dir))) + def list_files(self) -> list[pathlib.Path]: + files = (pathlib.Path(p) + for p in self.files_pathspec.match_tree_files(str(self.template_dir))) if self.used_readme is not None: return list(p for p in files if p != self.used_readme.relative_to(self.template_dir)) return list(files) - def _relative_paths_eq(self, filepath1: Optional[pathlib.Path], filepath2: Optional[pathlib.Path]) -> bool: + def _relative_paths_eq(self, filepath1: pathlib.Path | None, + filepath2: pathlib.Path | None) -> bool: if filepath1 is None or filepath2 is None: return False return filepath1.relative_to(self.template_dir) == filepath2.relative_to(self.template_dir) - def is_template_file(self, filepath: pathlib.Path, include_descriptor: bool = False, include_readme: bool = False): + def is_template_file(self, filepath: pathlib.Path, include_descriptor: bool = False, + include_readme: bool = False): if include_readme and self._relative_paths_eq(filepath, self.used_readme): return True if include_descriptor and self._relative_paths_eq(filepath, self.descriptor_path): @@ -401,7 +407,7 @@ def update_template_file(self, tfile: TemplateFile): filename = tfile.filename.as_posix() self.safe_template.files[filename] = tfile - def get_template_file(self, filepath: pathlib.Path) -> Optional[TemplateFile]: + def get_template_file(self, filepath: pathlib.Path) -> TemplateFile | None: if filepath.is_absolute(): filepath = filepath.relative_to(self.template_dir) return self.safe_template.files.get(filepath.as_posix(), None) @@ -420,7 +426,10 @@ def _write_file(self, filepath: pathlib.Path, contents: bytes, force: bool): def store_descriptor(self, force: bool): self._write_file( filepath=self.descriptor_path, - contents=json.dumps(self.safe_template.serialize_local(), indent=4).encode(encoding=DEFAULT_ENCODING), + contents=json.dumps( + obj=self.safe_template.serialize_local(), + indent=4 + ).encode(encoding=DEFAULT_ENCODING), force=force, ) diff --git a/packages/dsw-tdk/dsw/tdk/utils.py b/packages/dsw-tdk/dsw/tdk/utils.py index 496e60da..7fa1814d 100644 --- a/packages/dsw-tdk/dsw/tdk/utils.py +++ b/packages/dsw-tdk/dsw/tdk/utils.py @@ -1,8 +1,7 @@ -import jinja2 # type: ignore import pathlib import uuid -from typing import List, Set, Optional +import jinja2 from .consts import DEFAULT_ENCODING, DEFAULT_README from .model import Template, TemplateFile, Format, Step, PackageFilter @@ -20,10 +19,10 @@ class UUIDGen: - _uuids = set() # type: Set[uuid.UUID] + _uuids: set[uuid.UUID] = set() @classmethod - def used(cls) -> Set[uuid.UUID]: + def used(cls) -> set[uuid.UUID]: return cls._uuids @classmethod @@ -91,10 +90,10 @@ class TemplateBuilder: def __init__(self): self.template = Template() - self._formats = [] # type: List[FormatSpec] + self._formats: list[FormatSpec] = [] @property - def formats(self) -> List[FormatSpec]: + def formats(self) -> list[FormatSpec]: return self._formats def _validate_field(self, field_name: str): @@ -177,13 +176,13 @@ def build(self) -> Template: license_file = j2_env.get_template('LICENSE.j2').render(template=self.template) self.template.tdk_config.files.append('LICENSE') self.template.files['LICENSE'] = TemplateFile( - filename='LICENSE', + filename=pathlib.Path('LICENSE'), content_type='text/plain', content=license_file.encode(encoding=DEFAULT_ENCODING), ) self.template.files['.env'] = TemplateFile( - filename='.env', + filename=pathlib.Path('.env'), content_type='text/plain', content=create_dot_env().encode(encoding=DEFAULT_ENCODING), ) @@ -191,7 +190,7 @@ def build(self) -> Template: return self.template -def create_dot_env(api_url: Optional[str] = None, api_key: Optional[str] = None) -> str: +def create_dot_env(api_url: str | None = None, api_key: str | None = None) -> str: return j2_env.get_template('env.j2').render(api_url=api_url, api_key=api_key) diff --git a/packages/dsw-tdk/dsw/tdk/validation.py b/packages/dsw-tdk/dsw/tdk/validation.py index 8ce2006c..866d6b2f 100644 --- a/packages/dsw-tdk/dsw/tdk/validation.py +++ b/packages/dsw-tdk/dsw/tdk/validation.py @@ -15,55 +15,82 @@ def __init__(self, field_name: str, message: str): def _validate_required(field_name: str, value) -> List[ValidationError]: if value is None: - return [ValidationError(field_name, 'Missing but it is required')] + return [ValidationError( + field_name=field_name, + message='Missing but it is required', + )] return [] def _validate_non_empty(field_name: str, value) -> List[ValidationError]: if value is not None and len(value.strip()) == 0: - return [ValidationError(field_name, 'Cannot be empty or only-whitespace')] + return [ValidationError( + field_name=field_name, + message='Cannot be empty or only-whitespace', + )] return [] def _validate_content_type(field_name: str, value) -> List[ValidationError]: if value is not None and re.match(REGEX_MIME_TYPE, value) is None: - return [ValidationError(field_name, 'Content type should be valid IANA media type')] + return [ValidationError( + field_name=field_name, + message='Content type should be valid IANA media type', + )] return [] def _validate_extension(field_name: str, value) -> List[ValidationError]: if value is not None and re.match(REGEX_ORGANIZATION_ID, value) is None: - return [ValidationError(field_name, 'File extension should contain only letters, numbers and dots (inside-only)')] + return [ValidationError( + field_name=field_name, + message='File extension should contain only letters, numbers and dots (inside-only)', + )] return [] def _validate_organization_id(field_name: str, value) -> List[ValidationError]: if value is not None and re.match(REGEX_ORGANIZATION_ID, value) is None: - return [ValidationError(field_name, 'Organization ID may contain only letters, numbers, and period (inside-only)')] + return [ValidationError( + field_name=field_name, + message='Organization ID may contain only letters, numbers, and period (inside-only)', + )] return [] def _validate_template_id(field_name: str, value) -> List[ValidationError]: if value is not None and re.match(REGEX_TEMPLATE_ID, value) is None: - return [ValidationError(field_name, 'Template ID may contain only letters, numbers, and dash (inside-only)')] + return [ValidationError( + field_name=field_name, + message='Template ID may contain only letters, numbers, and dash (inside-only)', + )] return [] def _validate_km_id(field_name: str, value) -> List[ValidationError]: if value is not None and re.match(REGEX_KM_ID, value) is None: - return [ValidationError(field_name, 'KM ID may contain only letters, numbers, and dash (inside-only)')] + return [ValidationError( + field_name=field_name, + message='KM ID may contain only letters, numbers, and dash (inside-only)', + )] return [] def _validate_version(field_name: str, value) -> List[ValidationError]: if value is not None and re.match(REGEX_SEMVER, value) is None: - return [ValidationError(field_name, 'Version must be in semver format ..')] + return [ValidationError( + field_name=field_name, + message='Version must be in semver format ..', + )] return [] def _validate_natural(field_name: str, value) -> List[ValidationError]: if value is not None and (not isinstance(value, int) or value < 1): - return [ValidationError(field_name, 'It must be positive integer')] + return [ValidationError( + field_name=field_name, + message='It must be positive integer', + )] return [] @@ -72,16 +99,31 @@ def _validate_package_id(field_name: str, value: str) -> List[ValidationError]: if value is None: return res if not isinstance(value, str): - return [ValidationError(field_name, 'Package ID is not a string')] + return [ValidationError( + field_name=field_name, + message='Package ID is not a string', + )] parts = value.split(':') if len(parts) != 3: - res.append(ValidationError(field_name, 'Package ID is not valid (only {len(parts)} parts)')) + res.append(ValidationError( + field_name=field_name, + message='Package ID is not valid (only {len(parts)} parts)', + )) if re.match(REGEX_ORGANIZATION_ID, parts[0]) is None: - res.append(ValidationError(field_name, 'Package ID contains invalid organization id')) + res.append(ValidationError( + field_name=field_name, + message='Package ID contains invalid organization id', + )) if re.match(REGEX_KM_ID, parts[1]) is None: - res.append(ValidationError(field_name, 'Package ID contains invalid KM id')) + res.append(ValidationError( + field_name=field_name, + message='Package ID contains invalid KM id', + )) if re.match(REGEX_SEMVER, parts[2]) is None: - res.append(ValidationError(field_name, 'Package ID contains invalid version')) + res.append(ValidationError( + field_name=field_name, + message='Package ID contains invalid version', + )) return res @@ -91,11 +133,20 @@ def _validate_jinja_options(field_name: str, value: Dict[str, str]) -> List[Vali return res for k in ('template', 'content-type', 'extension'): if k not in value.keys(): - res.append(ValidationError(field_name, 'Jinja option cannot be left out')) + res.append(ValidationError( + field_name=field_name, + message='Jinja option cannot be left out', + )) elif value[k] is None or not isinstance(value[k], str) or len(value[k]) == 0: - res.append(ValidationError(field_name, 'Jinja option cannot be empty')) + res.append(ValidationError( + field_name=field_name, + message='Jinja option cannot be empty', + )) if 'content-type' in value.keys(): - res.extend(_validate_content_type(f'{field_name}.content-type', value['content-type'])) + res.extend(_validate_content_type( + field_name=f'{field_name}.content-type', + value=value['content-type'], + )) return res @@ -129,7 +180,10 @@ def collect_errors(self, entity, field_name_prefix: str = '') -> List[Validation if field_name.startswith('__'): continue for validator in validators: - result.extend(validator(field_name_prefix + field_name, getattr(entity, field_name))) + result.extend(validator( + field_name=f'{field_name_prefix}{field_name}', + value=getattr(entity, field_name), + )) if '__all' in self.rules.keys(): result.extend(self.rules['__all'](field_name_prefix, entity)) return result diff --git a/packages/dsw-tdk/pyproject.toml b/packages/dsw-tdk/pyproject.toml index 78afdcae..d016f033 100644 --- a/packages/dsw-tdk/pyproject.toml +++ b/packages/dsw-tdk/pyproject.toml @@ -17,13 +17,13 @@ classifiers = [ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Internet :: WWW/HTTP', 'Topic :: Utilities', ] -requires-python = '>=3.9, <4' +requires-python = '>=3.10, <4' dependencies = [ 'aiohttp', 'click', diff --git a/scripts/clean.sh b/scripts/clean.sh new file mode 100755 index 00000000..cc71533c --- /dev/null +++ b/scripts/clean.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env sh +set -e + +for PKG in $(ls packages); do + echo "Cleaning $PKG" + rm -rf "packages/$PKG/build" + rm -rf "packages/$PKG/env" + find "packages/$PKG" | grep -E "(.egg-info$)" | xargs rm -rf + find "packages/$PKG" | grep -E "(/__pycache__$|\.pyc$|\.pyo$)" | xargs rm -rf +done