diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8f01da2c..0a4be21c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,6 +7,8 @@ jobs: runs-on: ubuntu-latest defaults: run: + # This is necessary for the conda action. It replaces `conda init` as + # the shell does not load ,profile or .bashrc. shell: bash -el {0} steps: - uses: actions/checkout@v3 @@ -24,6 +26,43 @@ jobs: run: python -m pytest --cov narupa python-libraries -n auto -m 'not serial' - name: Serial tests run: python -m pytest --cov narupa python-libraries -n auto -m 'serial' + mypy: + name: Type analysis for python + runs-on: ubuntu-latest + defaults: + run: + # This is necessary for the conda action. It replaces `conda init` as + # the shell does not load ,profile or .bashrc. + shell: bash -el {0} + steps: + - uses: actions/checkout@v3 + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + miniforge-version: latest + - name: Install narupa dependancies + run: conda install mpi4py openmm + - name: Install tests dependancies + run: python -m pip install -r python-libraries/requirements.test + - name: Compile + run: ./compile.sh --no-dotnet + - name: mypy + run: | + # Mypy accepts paths, modules or packages as inputs. However, only + # packages work reasonably well with packages. So we need to generate + # a list of the packages we want to test. + packages=$(find python-libraries -name __init__.py \ + | sed 's/__init__.py//g' \ + | awk '{split($0, a, /src/); print(a[2])}' \ + | sed 's#/#.#g' \ + | cut -c 2- \ + | sed 's/\.$//g' \ + | grep -v '^$' \ + | grep -v protocol \ + | sed 's/^/-p /g' \ + | grep -v '\..*\.' \ + | tr '\n' ' ') + mypy --ignore-missing-imports --namespace-packages --check-untyped-defs --allow-redefinition $packages csharp-tests: name: C# unit and integration tests runs-on: ubuntu-latest diff --git a/python-libraries/narupa-ase/src/narupa/ase/converter.py b/python-libraries/narupa-ase/src/narupa/ase/converter.py index 52bd8091..bf54363b 100644 --- a/python-libraries/narupa-ase/src/narupa/ase/converter.py +++ b/python-libraries/narupa-ase/src/narupa/ase/converter.py @@ -10,6 +10,7 @@ from ase import Atoms, Atom # type: ignore import itertools import numpy as np +import numpy.typing as npt from narupa.trajectory import FrameData @@ -160,7 +161,7 @@ def add_frame_data_positions_to_ase(frame_data, ase_atoms): ase_atoms.set_positions(np.array(frame_data.particle_positions) * NM_TO_ANG) -def add_ase_positions_to_frame_data(data: FrameData, positions: np.array): +def add_ase_positions_to_frame_data(data: FrameData, positions: npt.NDArray): """ Adds ASE positions to the frame data, converting to nanometers. @@ -196,7 +197,7 @@ def add_ase_topology_to_frame_data(frame_data: FrameData, ase_atoms: Atoms, gene frame_data.residue_names = ["ASE"] frame_data.residue_chains = [0] frame_data.residue_count = 1 - frame_data.residue_ids =['1'] + frame_data.residue_ids = ['1'] frame_data.chain_names = ["A"] frame_data.chain_count = 1 diff --git a/python-libraries/narupa-ase/src/narupa/ase/imd.py b/python-libraries/narupa-ase/src/narupa/ase/imd.py index 758dd528..c7f16d17 100644 --- a/python-libraries/narupa-ase/src/narupa/ase/imd.py +++ b/python-libraries/narupa-ase/src/narupa/ase/imd.py @@ -321,8 +321,9 @@ def _register_commands(self): self._server.register_command(RESET_COMMAND_KEY, self.reset) self._server.register_command(STEP_COMMAND_KEY, self.step) self._server.register_command(PAUSE_COMMAND_KEY, self.pause) - self._server.register_command(SET_DYNAMICS_INTERVAL_COMMAND_KEY, self._set_dynamics_interval) - self._server.register_command(GET_DYNAMICS_INTERVAL_COMMAND_KEY, self._get_dynamics_interval) + # TODO: Fix type annotations for commands with arguments + self._server.register_command(SET_DYNAMICS_INTERVAL_COMMAND_KEY, self._set_dynamics_interval) # type: ignore + self._server.register_command(GET_DYNAMICS_INTERVAL_COMMAND_KEY, self._get_dynamics_interval) # type: ignore def close(self): """ diff --git a/python-libraries/narupa-ase/src/narupa/ase/imd_calculator.py b/python-libraries/narupa-ase/src/narupa/ase/imd_calculator.py index db459ed4..9c567b0f 100644 --- a/python-libraries/narupa-ase/src/narupa/ase/imd_calculator.py +++ b/python-libraries/narupa-ase/src/narupa/ase/imd_calculator.py @@ -41,6 +41,8 @@ class ImdCalculator(Calculator): """ + _previous_interactions: Dict[str, ParticleInteraction] + def __init__(self, imd_state: ImdStateWrapper, calculator: Optional[Calculator] = None, atoms: Optional[Atoms] = None, diff --git a/python-libraries/narupa-ase/src/narupa/ase/openmm/runner.py b/python-libraries/narupa-ase/src/narupa/ase/openmm/runner.py index 812289ad..4e3670ff 100644 --- a/python-libraries/narupa-ase/src/narupa/ase/openmm/runner.py +++ b/python-libraries/narupa-ase/src/narupa/ase/openmm/runner.py @@ -203,10 +203,14 @@ def from_xml(cls, simulation_xml, :param logging_params: The :class:LoggingParams to set up trajectory logging with. :return: An OpenMM simulation runner. """ + if params is not None: + platform = params.platform + else: + platform = None with open(simulation_xml) as infile: simulation = serializer.deserialize_simulation( infile.read(), - platform_name=params.platform, + platform_name=platform, ) return cls(simulation, params, logging_params) diff --git a/python-libraries/narupa-ase/src/narupa/ase/trajectory_logger.py b/python-libraries/narupa-ase/src/narupa/ase/trajectory_logger.py index 4729386f..cb591c60 100644 --- a/python-libraries/narupa-ase/src/narupa/ase/trajectory_logger.py +++ b/python-libraries/narupa-ase/src/narupa/ase/trajectory_logger.py @@ -8,6 +8,7 @@ import datetime import os +from io import TextIOWrapper from typing import Optional import ase.io @@ -83,6 +84,11 @@ class TrajectoryLogger: """ format: str + frame_index: int + atoms: Atoms + base_path: str + parallel: bool + _file_descriptor: Optional[TextIOWrapper] def __init__(self, atoms: Atoms, filename: str, format: Optional[str] = None, timestamp=True, parallel=True, **kwargs): diff --git a/python-libraries/narupa-core/src/narupa/app/app_server.py b/python-libraries/narupa-core/src/narupa/app/app_server.py index f740ce8d..ff806c61 100644 --- a/python-libraries/narupa-core/src/narupa/app/app_server.py +++ b/python-libraries/narupa-core/src/narupa/app/app_server.py @@ -171,7 +171,7 @@ def close(self): Close the application server and all services. """ if self.running_discovery: - self._discovery.close() + self._discovery.close() # type: ignore for service in self._services: service.close() self._server.close() @@ -192,10 +192,10 @@ def _add_service_entry(self, name: str, port: int): def _update_discovery_services(self): try: - self._discovery.unregister_service(self._service_hub) + self._discovery.unregister_service(self._service_hub) # type: ignore except KeyError: pass - self._discovery.register_service(self._service_hub) + self._discovery.register_service(self._service_hub) # type: ignore def qualified_server_name(base_name: str): diff --git a/python-libraries/narupa-core/src/narupa/app/client.py b/python-libraries/narupa-core/src/narupa/app/client.py index 141aba9b..42532c05 100644 --- a/python-libraries/narupa-core/src/narupa/app/client.py +++ b/python-libraries/narupa-core/src/narupa/app/client.py @@ -8,7 +8,7 @@ import warnings from collections import deque, ChainMap from functools import wraps, partial -from typing import Iterable, Tuple, Type, TypeVar +from typing import Iterable, Tuple, Type, TypeVar, cast from typing import Optional, Sequence, Dict, MutableMapping from uuid import uuid4 @@ -162,9 +162,9 @@ class NarupaImdClient: _subscribed_to_all_frames: bool def __init__(self, *, - trajectory_address: Tuple[str, int] = None, - imd_address: Tuple[str, int] = None, - multiplayer_address: Tuple[str, int] = None, + trajectory_address: Optional[Tuple[str, int]] = None, + imd_address: Optional[Tuple[str, int]] = None, + multiplayer_address: Optional[Tuple[str, int]] = None, max_frames=50, all_frames: Optional[bool] = None, ): @@ -320,9 +320,9 @@ def connect_multiplayer(self, address: Tuple[str, int]): self._multiplayer_client = self._connect_client(NarupaClient, address) def connect(self, *, - trajectory_address: Tuple[str, int] = None, - imd_address: Tuple[str, int] = None, - multiplayer_address: Tuple[str, int] = None, + trajectory_address: Optional[Tuple[str, int]] = None, + imd_address: Optional[Tuple[str, int]] = None, + multiplayer_address: Optional[Tuple[str, int]] = None, ): """ Connects the client to all services for which addresses are provided. @@ -582,7 +582,8 @@ def subscribe_multiplayer(self, interval=DEFAULT_STATE_UPDATE_INTERVAL): :raises grpc._channel._Rendezvous: When not connected to a multiplayer service """ - self._multiplayer_client.subscribe_all_state_updates(interval) + multiplayer_client = cast(NarupaClient, self._multiplayer_client) + multiplayer_client.subscribe_all_state_updates(interval) @need_multiplayer def attempt_update_multiplayer_state( @@ -796,8 +797,8 @@ def subscribe_to_all_frames(self): if self._are_framed_subscribed: return self._subscribed_to_all_frames = True - self._frame_client: FrameClient # @need_frames makes sure of that - self._frame_client.subscribe_frames_async(self._on_frame_received) + frame_client = cast(FrameClient, self._frame_client) # @need_frames makes sure of that + frame_client.subscribe_frames_async(self._on_frame_received) self._are_framed_subscribed = True @need_frames @@ -831,8 +832,8 @@ def subscribe_to_frames(self, interval: float = DEFAULT_SUBSCRIPTION_INTERVAL): if self._are_framed_subscribed: return self._subscribed_to_all_frames = False - self._frame_client: FrameClient # @need_frames makes sure of that - self._frame_client.subscribe_last_frames_async( + frame_client = cast(FrameClient, self._frame_client) # @need_frames makes sure of that + frame_client.subscribe_last_frames_async( self._on_frame_received, DEFAULT_SUBSCRIPTION_INTERVAL, ) diff --git a/python-libraries/narupa-core/src/narupa/command/command_service.py b/python-libraries/narupa-core/src/narupa/command/command_service.py index 24f0935c..67361d29 100644 --- a/python-libraries/narupa-core/src/narupa/command/command_service.py +++ b/python-libraries/narupa-core/src/narupa/command/command_service.py @@ -4,7 +4,7 @@ Module providing an implementation of the :class:`CommandServicer`. """ -from typing import Dict, Callable, Optional +from typing import Dict, Callable, Optional, Union from typing import NamedTuple import grpc @@ -22,7 +22,10 @@ from narupa.utilities.key_lockable_map import KeyLockableMap from narupa.utilities.protobuf_utilities import dict_to_struct, struct_to_dict -CommandHandler = Callable[[CommandArguments], Optional[CommandResult]] +CommandHandler = Union[ + Callable[[CommandArguments], Optional[CommandResult]], + Callable[[], Optional[CommandResult]], +] class CommandRegistration(NamedTuple): diff --git a/python-libraries/narupa-core/src/narupa/core/narupa_client.py b/python-libraries/narupa-core/src/narupa/core/narupa_client.py index 87441786..a837f8df 100644 --- a/python-libraries/narupa-core/src/narupa/core/narupa_client.py +++ b/python-libraries/narupa-core/src/narupa/core/narupa_client.py @@ -5,7 +5,7 @@ # Copyright (c) Intangible Realities Lab, University Of Bristol. All rights reserved. # Licensed under the GPL. See License.txt in the project root for license information. -from typing import Dict, Iterable, ContextManager, Union, Any +from typing import Dict, Iterable, ContextManager, Union, Any, Mapping from uuid import uuid4 import grpc @@ -139,7 +139,7 @@ def attempt_update_state(self, change: DictionaryChange) -> bool: def attempt_update_locks( self, - lock_updates: Dict[str, Union[float, None]] + lock_updates: Mapping[str, Union[float, None]] ) -> bool: """ Attempt to acquire and/or free a number of locks on the shared state. diff --git a/python-libraries/narupa-core/src/narupa/core/narupa_server.py b/python-libraries/narupa-core/src/narupa/core/narupa_server.py index d944193e..d7fadbb9 100644 --- a/python-libraries/narupa-core/src/narupa/core/narupa_server.py +++ b/python-libraries/narupa-core/src/narupa/core/narupa_server.py @@ -1,9 +1,9 @@ # Copyright (c) Intangible Realities Lab, University Of Bristol. All rights reserved. # Licensed under the GPL. See License.txt in the project root for license information. -from typing import Callable, Optional, Dict, ContextManager, Set +from typing import Callable, Optional, Dict, ContextManager, Set, Union from narupa.command import CommandService -from narupa.command.command_service import CommandRegistration +from narupa.command.command_service import CommandRegistration, CommandHandler from narupa.core import GrpcServer from narupa.state.state_service import StateService from narupa.utilities.change_buffers import ( @@ -11,6 +11,12 @@ DictionaryChangeBuffer, ) +CommandCallable = Union[ + Callable[[Dict], Optional[Dict]], + Callable[[], None], + Callable[[Dict], None] +] + class NarupaServer(GrpcServer): """ @@ -42,7 +48,7 @@ def commands(self) -> Dict[str, CommandRegistration]: """ return self._command_service.commands - def register_command(self, name: str, callback: Callable[[Dict], Optional[Dict]], + def register_command(self, name: str, callback: CommandHandler, default_arguments: Optional[Dict] = None): """ Registers a command with the :class:`CommandService` running on this server. diff --git a/python-libraries/narupa-core/src/narupa/imd/imd_client.py b/python-libraries/narupa-core/src/narupa/imd/imd_client.py index bec4808f..2a1180e9 100644 --- a/python-libraries/narupa-core/src/narupa/imd/imd_client.py +++ b/python-libraries/narupa-core/src/narupa/imd/imd_client.py @@ -2,7 +2,7 @@ # Licensed under the GPL. See License.txt in the project root for license information. import logging from uuid import uuid4 -from typing import Dict, Set +from typing import Dict, Set, Mapping import grpc from narupa.core import NarupaClient @@ -37,6 +37,10 @@ def interactions(self) -> Dict[str, ParticleInteraction]: key: dict_to_interaction(value) for key, value in state.items() if key.startswith(INTERACTION_PREFIX) + # We can have a misformatted interactions in the shared state. + # Here we ignore them silently. + # TODO: log a warning when this happens. + and isinstance(value, Mapping) } def start_interaction(self) -> str: diff --git a/python-libraries/narupa-core/src/narupa/imd/imd_force.py b/python-libraries/narupa-core/src/narupa/imd/imd_force.py index 22120d1f..cfc269c8 100644 --- a/python-libraries/narupa-core/src/narupa/imd/imd_force.py +++ b/python-libraries/narupa-core/src/narupa/imd/imd_force.py @@ -14,15 +14,16 @@ from typing import Collection, Tuple, Optional, Iterable import numpy as np +import numpy.typing as npt from narupa.imd.particle_interaction import ParticleInteraction def calculate_imd_force( - positions: np.ndarray, - masses: np.ndarray, + positions: npt.NDArray, + masses: npt.NDArray, interactions: Iterable[ParticleInteraction], - periodic_box_lengths: Optional[np.ndarray] = None, -) -> Tuple[float, np.array]: + periodic_box_lengths: Optional[npt.NDArray] = None, +) -> Tuple[float, npt.NDArray]: """ Reference implementation of the Narupa IMD force. @@ -46,8 +47,8 @@ def calculate_imd_force( return total_energy, forces -def apply_single_interaction_force(positions: np.ndarray, masses: np.ndarray, interaction, forces: np.ndarray, - periodic_box_lengths: Optional[np.array] = None) -> float: +def apply_single_interaction_force(positions: npt.NDArray, masses: npt.NDArray, interaction, forces: npt.NDArray, + periodic_box_lengths: Optional[npt.NDArray] = None) -> float: """ Calculates the energy and adds the forces to the particles of a single application of an interaction potential. @@ -154,6 +155,7 @@ def get_center_of_mass_subset( before calculating centre of mass. :return: The center of mass of the subset of positions. """ + subset = list(subset) subset_positions = positions[subset] subset_masses = masses[subset, np.newaxis] subset_total_mass = subset_masses.sum() @@ -194,8 +196,8 @@ def calculate_gaussian_force(particle_position: np.ndarray, interaction_position return energy, force -def calculate_spring_force(particle_position: np.array, interaction_position: np.array, k=1, - periodic_box_lengths: Optional[np.ndarray] = None) -> Tuple[float, np.array]: +def calculate_spring_force(particle_position: npt.NDArray, interaction_position: npt.NDArray, k=1, + periodic_box_lengths: Optional[npt.NDArray] = None) -> Tuple[float, npt.NDArray]: """ Computes the interactive harmonic potential (or spring) force. diff --git a/python-libraries/narupa-core/src/narupa/imd/imd_state.py b/python-libraries/narupa-core/src/narupa/imd/imd_state.py index dd7573f4..c674ada0 100644 --- a/python-libraries/narupa-core/src/narupa/imd/imd_state.py +++ b/python-libraries/narupa-core/src/narupa/imd/imd_state.py @@ -3,7 +3,7 @@ """ Module providing methods for storing ParticleInteractions in a StateDictionary. """ -from typing import Dict, Any +from typing import Dict, Any, Mapping from narupa.state.state_dictionary import StateDictionary from narupa.utilities.change_buffers import DictionaryChange @@ -106,7 +106,7 @@ def interaction_to_dict(interaction: ParticleInteraction) -> Dict[str, Serializa raise TypeError from e -def dict_to_interaction(dictionary: Dict[str, Any]) -> ParticleInteraction: +def dict_to_interaction(dictionary: Mapping[str, Any]) -> ParticleInteraction: kwargs = dict(**dictionary) if 'particles' in kwargs: kwargs['particles'] = [int(i) for i in kwargs['particles']] diff --git a/python-libraries/narupa-core/src/narupa/imd/particle_interaction.py b/python-libraries/narupa-core/src/narupa/imd/particle_interaction.py index f86e2ba3..d85bc3b2 100644 --- a/python-libraries/narupa-core/src/narupa/imd/particle_interaction.py +++ b/python-libraries/narupa-core/src/narupa/imd/particle_interaction.py @@ -4,8 +4,9 @@ Module providing a wrapper class around the protobuf interaction message. """ import math -from typing import Dict, Any, Iterable, Collection +from typing import Dict, Any, Iterable, Collection, Union import numpy as np +import numpy.typing as npt DEFAULT_MAX_FORCE = 20000.0 DEFAULT_FORCE_TYPE = "gaussian" @@ -80,7 +81,7 @@ def scale(self, value: float): self._scale = float(value) @property - def position(self) -> np.array: + def position(self) -> npt.NDArray[np.float64]: """ The position of the interaction in nanometers, which defaults to ``[0 0 0]`` """ @@ -101,7 +102,12 @@ def particles(self) -> np.ndarray: return self._particles @particles.setter - def particles(self, particles: Collection[int]): + def particles(self, particles: Union[list[int], tuple[int]]): + # We would like to type the `particles` argument as `Collection` and it + # should be precise enough. However, it appears not to be compatible + # with `npt.ArrayLike` in the context `np.unique`; and `ArrayLike` + # allows scalar that do not have a `len` method. Therefore we use a + # type hint that is likely more restrictive than needed. if len(particles) < 2: self._particles = np.array(particles) else: diff --git a/python-libraries/narupa-core/src/narupa/trajectory/frame_data.py b/python-libraries/narupa-core/src/narupa/trajectory/frame_data.py index ed6e13ab..72aebb05 100644 --- a/python-libraries/narupa-core/src/narupa/trajectory/frame_data.py +++ b/python-libraries/narupa-core/src/narupa/trajectory/frame_data.py @@ -3,9 +3,10 @@ from collections import namedtuple from collections.abc import Set import numbers -from typing import Dict, Optional, List +from typing import Dict, Optional, List, TypeVar, Union import numpy as np +import numpy.typing as npt from narupa.protocol import trajectory from narupa.utilities.protobuf_utilities import value_to_object, object_to_value @@ -39,6 +40,9 @@ '_Shortcut', ['record_type', 'key', 'field_type', 'to_python', 'to_raw'] ) +Array2Dfloat = Union[List[List[float]], npt.NDArray[Union[np.float32, np.float64]]] +Array2Dint = Union[List[List[int]], npt.NDArray[Union[np.int_]]] + class MissingDataError(KeyError): """ @@ -140,14 +144,14 @@ class FrameData(metaclass=_FrameDataMeta): The set of shortcuts that contain data is available from the :attr:`used_shortcuts`. """ - bond_pairs: List[List[int]] = _Shortcut( # type: ignore[assignment] + bond_pairs: Array2Dint = _Shortcut( # type: ignore[assignment] key=BOND_PAIRS, record_type='arrays', field_type='index', to_python=_n_by_2, to_raw=_flatten_array) bond_orders: List[float] = _Shortcut( # type: ignore[assignment] key=BOND_ORDERS, record_type='arrays', field_type='float', to_python=_as_is, to_raw=_as_is) - particle_positions: List[List[float]] = _Shortcut( # type: ignore[assignment] + particle_positions: Array2Dfloat = _Shortcut( # type: ignore[assignment] key=PARTICLE_POSITIONS, record_type='arrays', field_type='float', to_python=_n_by_3, to_raw=_flatten_array) particle_elements: List[int] = _Shortcut( # type: ignore[assignment] @@ -169,7 +173,7 @@ class FrameData(metaclass=_FrameDataMeta): residue_names: List[str] = _Shortcut( # type: ignore[assignment] key=RESIDUE_NAMES, record_type='arrays', field_type='string', to_python=_as_is, to_raw=_as_is) - residue_ids: List[int] = _Shortcut( # type: ignore[assignment] + residue_ids: List[str] = _Shortcut( # type: ignore[assignment] key=RESIDUE_IDS, record_type='arrays', field_type='string', to_python=_as_is, to_raw=_as_is) residue_chains: List[int] = _Shortcut( # type: ignore[assignment] @@ -364,7 +368,7 @@ class ValuesView(RecordView): def _convert_to_python(field): return value_to_object(field) - def set(self, key, value): + def set(self, key: str, value): self._raw_record[key].CopyFrom(object_to_value(value)) @@ -379,7 +383,7 @@ class ArraysView(RecordView): def _convert_to_python(field): return field.ListFields()[0][1].values - def set(self, key, value): + def set(self, key: str, value): try: reference_value = value[0] except IndexError: @@ -387,7 +391,7 @@ def set(self, key, value): except TypeError: raise ValueError('Value must be indexable.') - if isinstance(reference_value, numbers.Integral) and reference_value >= 0: + if isinstance(reference_value, numbers.Integral) and int(reference_value) >= 0: type_attribute = 'index_values' elif isinstance(reference_value, numbers.Real): type_attribute = 'float_values' diff --git a/python-libraries/narupa-core/src/narupa/utilities/change_buffers.py b/python-libraries/narupa-core/src/narupa/utilities/change_buffers.py index 98957904..fec840c7 100644 --- a/python-libraries/narupa-core/src/narupa/utilities/change_buffers.py +++ b/python-libraries/narupa-core/src/narupa/utilities/change_buffers.py @@ -87,7 +87,7 @@ def subscribe_changes(self, interval: float = 0) \ with self.create_view() as view: yield from view.subscribe_changes(interval) - def update(self, updates: KeyUpdates = None, removals: KeyRemovals = None): + def update(self, updates: Optional[KeyUpdates] = None, removals: Optional[KeyRemovals] = None): """ Updates the shared dictionary with key values pairs from :updates: and key removals from :removals:. @@ -167,7 +167,7 @@ def freeze(self): self._frozen = True self._any_changes.notify() - def update(self, updates: KeyUpdates = None, removals: KeyRemovals = None): + def update(self, updates: Optional[KeyUpdates] = None, removals: Optional[KeyRemovals] = None): """ Update the known changes from a dictionary of keys that have changed to their new values or have been removed. diff --git a/python-libraries/narupa-core/src/narupa/utilities/protobuf_utilities.py b/python-libraries/narupa-core/src/narupa/utilities/protobuf_utilities.py index e187e7eb..8b89aeb4 100644 --- a/python-libraries/narupa-core/src/narupa/utilities/protobuf_utilities.py +++ b/python-libraries/narupa-core/src/narupa/utilities/protobuf_utilities.py @@ -21,12 +21,13 @@ _Level0Mapping = Mapping[str, Union[_SerializablePrimitive, _TerminalIterable, _TerminalMapping]] Serializable = Union[ _SerializablePrimitive, + _TerminalIterable, + _TerminalMapping, _Level0Iterable, - _Level0Mapping, + _Level0Mapping ] - -def dict_to_struct(dictionary: Dict[str, Serializable]) -> Struct: +def dict_to_struct(dictionary: Mapping[str, Serializable]) -> Struct: """ Converts a python dictionary to a protobuf :class:`Struct`. The dictionary must consist of types that can be serialised. @@ -36,7 +37,7 @@ def dict_to_struct(dictionary: Dict[str, Serializable]) -> Struct: """ struct = Struct() try: - struct.update(dictionary) + struct.update(dictionary) # type: ignore except (ValueError, TypeError, AttributeError): raise ValueError( 'Could not convert object into a protobuf struct. The object to ' diff --git a/python-libraries/narupa-core/tests/imd/test_imd_client.py b/python-libraries/narupa-core/tests/imd/test_imd_client.py index 0f0eca72..8204375f 100644 --- a/python-libraries/narupa-core/tests/imd/test_imd_client.py +++ b/python-libraries/narupa-core/tests/imd/test_imd_client.py @@ -1,6 +1,13 @@ +import itertools import time import pytest from narupa.imd.particle_interaction import ParticleInteraction +from narupa.imd.imd_state import ( + INTERACTION_PREFIX, + interaction_to_dict, + dict_to_interaction, +) +from narupa.utilities.change_buffers import DictionaryChange from .test_imd_server import imd_server_client, imd_server, interaction @@ -132,3 +139,48 @@ def test_subscribe_own_interaction_removed(imd_server_client): time.sleep(IMMEDIATE_REPLY_WAIT_TIME * 5) assert interaction_id not in imd_client.interactions + + +def test_interactions_property(imd_server_client): + """ + Test that the `interactions` property returns what is expected. + """ + imd_server, imd_client = imd_server_client + + # The IMD server does some validation and conversion of the interactions. + # We want to add invalid interactions in the shared state to emulate a + # possible third party server or a bug in the server. So we need to remove + # the conversion step. + imd_server._state_service.state_dictionary.content_updated._callbacks = [] + + # There are the interactions we should get when calling the property. + real_interactions = { + f'{INTERACTION_PREFIX}.first_id': ParticleInteraction(), + f'{INTERACTION_PREFIX}.second_id': ParticleInteraction(), + } + # We should not get these interactions because the ID is not one of an + # interaction. + interactions_with_incompatible_id = { + 'not.a.valid.interaction.id': ParticleInteraction(), + } + interaction_updates = DictionaryChange(updates = { + key: interaction_to_dict(interaction) + for key, interaction + in itertools.chain( + real_interactions.items(), + interactions_with_incompatible_id.items(), + ) + }) + imd_server.update_state(None, interaction_updates) + + # These interactions have an ID matching an interaction, but the content does not match. + incorrect_interactions = { + f'{INTERACTION_PREFIX}.third_id': 'not.an.interaction', + } + interaction_updates = DictionaryChange(updates = incorrect_interactions) + imd_server.update_state(None, interaction_updates) + + imd_client.subscribe_all_state_updates(interval=0) + time.sleep(IMMEDIATE_REPLY_WAIT_TIME) + print(imd_client.interactions) + assert imd_client.interactions.keys() == real_interactions.keys() diff --git a/python-libraries/narupa-essd/src/narupa/essd/server.py b/python-libraries/narupa-essd/src/narupa/essd/server.py index b0ccb9b4..340821a4 100644 --- a/python-libraries/narupa-essd/src/narupa/essd/server.py +++ b/python-libraries/narupa-essd/src/narupa/essd/server.py @@ -19,9 +19,9 @@ import threading import time from socket import socket, AF_INET, SOCK_DGRAM, SOL_SOCKET, SO_BROADCAST, SO_REUSEADDR -from typing import Optional, Dict +from typing import Optional, Dict, List -from narupa.essd.utils import get_broadcast_addresses, is_in_network, resolve_host_broadcast_address +from narupa.essd.utils import get_broadcast_addresses, is_in_network, resolve_host_broadcast_address, InterfaceAddresses from narupa.essd.servicehub import ServiceHub BROADCAST_PORT = 54545 @@ -42,8 +42,9 @@ def configure_reusable_socket() -> socket: class DiscoveryServer: - services: Dict[str, ServiceHub] + services: Dict[ServiceHub, List[InterfaceAddresses]] _socket: socket + _broadcast_thread: Optional[threading.Thread] def __init__(self, broadcast_port: Optional[int] = None, delay=0.5): if broadcast_port is None: @@ -109,7 +110,8 @@ def start(self): def close(self): if self.is_broadcasting: self._cancel = True - self._broadcast_thread.join() + # `is_broadcasting` made sure `_broadcast_thread` is not None + self._broadcast_thread.join() # type: ignore self._broadcast_thread = None self._cancel = False self._socket.close() @@ -140,7 +142,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def get_broadcast_addresses_for_service(self, service): + def get_broadcast_addresses_for_service(self, service) -> List[InterfaceAddresses]: address = service.address if address == "[::]": return self.broadcast_addresses diff --git a/python-libraries/narupa-essd/src/narupa/essd/utils.py b/python-libraries/narupa-essd/src/narupa/essd/utils.py index 3ba80ced..a6a53424 100644 --- a/python-libraries/narupa-essd/src/narupa/essd/utils.py +++ b/python-libraries/narupa-essd/src/narupa/essd/utils.py @@ -28,7 +28,7 @@ def get_ipv4_addresses(interfaces: Optional[Iterable[str]] = None) -> List[Inter return ipv4_addrs -def get_broadcast_addresses(interfaces: Optional[Iterable[str]] = None) -> List[Dict[str, str]]: +def get_broadcast_addresses(interfaces: Optional[Iterable[str]] = None) -> List[InterfaceAddresses]: """ Gets all the IPV4 addresses currently available on all the given interfaces that have broadcast addresses. @@ -54,7 +54,7 @@ def get_broadcast_addresses(interfaces: Optional[Iterable[str]] = None) -> List[ def resolve_host_broadcast_address( host: str, - ipv4_addrs: List[InterfaceAddresses] = None, + ipv4_addrs: Optional[List[InterfaceAddresses]] = None, ): try: address = socket.gethostbyname(host) @@ -88,7 +88,11 @@ def is_in_network(address: str, interface_address_entry: InterfaceAddresses) -> broadcast_address = ipaddress.ip_address(interface_address_entry['broadcast']) # to network address e.g. 255.255.255.0 & 192.168.1.255 = 192.168.1.0 network_address = ipaddress.ip_address(int(netmask) & int(broadcast_address)) - ip_network = ipaddress.ip_network((network_address, interface_address_entry['netmask'])) + # The doc and typing stub seem to indicate this is not a valid call of + # ipaddress.ip_network, but this is well tested so we accept it for the + # time being. + # TODO: Fix this line as the types seem to be incorrect. + ip_network = ipaddress.ip_network((network_address, interface_address_entry['netmask'])) # type: ignore except ValueError: raise ValueError(f'Given address {interface_address_entry} is not a valid IP network address.') except KeyError: @@ -101,4 +105,4 @@ def get_broadcastable_ip(): broadcast_addresses = get_broadcast_addresses() if len(broadcast_addresses) == 0: raise RuntimeError("No broadcastable IP addresses could be found on the system!") - return broadcast_addresses[0]['addr'] \ No newline at end of file + return broadcast_addresses[0]['addr'] diff --git a/python-libraries/narupa-lammps/src/narupa/lammps/LammpsImd.py b/python-libraries/narupa-lammps/src/narupa/lammps/LammpsImd.py index 40039cfd..471532e6 100644 --- a/python-libraries/narupa-lammps/src/narupa/lammps/LammpsImd.py +++ b/python-libraries/narupa-lammps/src/narupa/lammps/LammpsImd.py @@ -9,6 +9,7 @@ import logging from typing import List, Optional import numpy as np +import numpy.typing as npt try: from lammps import lammps # type: ignore @@ -70,7 +71,7 @@ class LammpsImd: """ need_to_collect_topology = True - def __init__(self, port: int = None, address: str = "[::]"): + def __init__(self, port: Optional[int] = None, address: str = "[::]"): """ Items that should be initialised on instantiation of lammpsHook class The MPI routines are essential to stop thread issues that cause internal @@ -110,11 +111,11 @@ def __init__(self, port: int = None, address: str = "[::]"): logging.info("Serving on %s ", port) # Set some variables that do not change during the simulation - self.n_atoms = None + self.n_atoms: Optional[int] = None self.units = None - self.units_type = None - self.force_factor = None - self.distance_factor = None + self.units_type: Optional[str] = None + self.force_factor: Optional[float] = None + self.distance_factor: Optional[float] = None self.masses = None self.atom_type = None self.n_atoms_in_dummy = 10 @@ -237,7 +238,7 @@ def _gather_lammps_particle_types(self, lammps_class): def _lammps_positions_to_frame_data(self, frame_data: FrameData, - data_array: np.array) -> FrameData: + data_array: npt.NDArray): """ Convert the flat ctype.c_double data into the frame_data format. for the moment this assumes we are in LAMMPS real units. Its unclear at this stage if is possible @@ -256,7 +257,7 @@ def _lammps_positions_to_frame_data(self, def _add_pos_to_framedata(self, frame_data, positions): frame_data.arrays[PARTICLE_POSITIONS] = positions - def _add_interaction_to_ctype(self, interaction_forces: np.array, lammps_forces): + def _add_interaction_to_ctype(self, interaction_forces: npt.NDArray, lammps_forces): """ Adds the interaction forces to the LAMMPS array @@ -311,6 +312,8 @@ def _manipulate_lammps_internal_matrix(self, lammps_class, positions, distance_f # Collect matrix from LAMMPS forces = self._gather_lammps_array(matrix_type, lammps_class) + if self.n_atoms is None: + raise ValueError("Number of atoms undefined.") # Convert the positions to a 2D, 3N array for use in calculate)imd_force positions_3n = np.ctypeslib.as_array(positions, shape=(self.n_atoms * 3)).reshape(self.n_atoms, 3) # Convert the positions to the narupa internal so that the forces are added in the correct position @@ -328,11 +331,13 @@ def _manipulate_lammps_internal_matrix(self, lammps_class, positions, distance_f if matrix_type == 'f': # Create numpy arrays with the forces to be added + if self.masses is None: + raise ValueError energy_kjmol, forces_kjmol = calculate_imd_force(positions_3n, self.masses, interactions.values()) + self._add_interaction_to_ctype(forces_kjmol, forces) # Convert the positions back so that they will render correctly. positions_3n *= self.distance_factor - self._add_interaction_to_ctype(forces_kjmol, forces) self._return_array_to_lammps(matrix_type, forces, lammps_class) def _extract_positions(self, lammps_class): @@ -379,9 +384,10 @@ def _extract_fundamental_factors(self, lammps_class): """ units = self.find_unit_type(lammps_class) n_atoms = lammps_class.get_natoms() - units_type = LAMMPS_UNITS_CHECK.get(units, None)[0] - distance_factor = LAMMPS_UNITS_CHECK.get(units, None)[1] - force_factor = LAMMPS_UNITS_CHECK.get(units, None)[2] + unit_check = LAMMPS_UNITS_CHECK.get(units, None) + if unit_check is None: + raise KeyError + units_type, distance_factor, force_factor = unit_check self._log_mpi("units : %s %s %s %s", self.me, units_type, force_factor, distance_factor) self.n_atoms = n_atoms self.distance_factor = distance_factor @@ -396,15 +402,17 @@ def _collect_and_send_frame_data(self, positions): :param positions: the flat (1D) positions arrays """ - self.frame_data.particle_count = self.n_atoms - self.frame_data.particle_elements = self.atom_type + if self.n_atoms is not None: + self.frame_data.particle_count = self.n_atoms + if self.atom_type is not None: + self.frame_data.particle_elements = self.atom_type self._lammps_positions_to_frame_data(self.frame_data, positions) # Send frame data self.frame_service.send_frame(self.frame_index, self.frame_data) self.frame_index += 1 - def _log_mpi(self, passed_string: str = None, *args, **kwargs): + def _log_mpi(self, passed_string: str, *args, **kwargs): """ Wrapper function for printing on one core only diff --git a/python-libraries/narupa-lammps/src/narupa/lammps/mock.py b/python-libraries/narupa-lammps/src/narupa/lammps/mock.py index f19d683c..4900e3bf 100644 --- a/python-libraries/narupa-lammps/src/narupa/lammps/mock.py +++ b/python-libraries/narupa-lammps/src/narupa/lammps/mock.py @@ -5,7 +5,7 @@ without LAMMPS installed. """ import ctypes -from typing import List, Union +from typing import List, Union, Optional class MockLammps: @@ -14,7 +14,7 @@ class MockLammps: without having to have LAMMPS installed on a server """ - def __init__(self, n_atoms_in_dummy: int = None): + def __init__(self, n_atoms_in_dummy: Optional[int] = None): # Set a default atom length for tests _DEFAULT_ATOMS = 3 self.n_atoms = n_atoms_in_dummy if n_atoms_in_dummy is not None else _DEFAULT_ATOMS diff --git a/python-libraries/narupa-openmm/src/narupa/openmm/converter.py b/python-libraries/narupa-openmm/src/narupa/openmm/converter.py index c8eba796..96e56ef4 100644 --- a/python-libraries/narupa-openmm/src/narupa/openmm/converter.py +++ b/python-libraries/narupa-openmm/src/narupa/openmm/converter.py @@ -71,7 +71,7 @@ def add_openmm_topology_to_frame_data(data: FrameData, topology: Topology) -> No residue_indices.append(atom.residue.index) for bond in topology.bonds(): - bonds.append((bond[0].index, bond[1].index)) + bonds.append([bond[0].index, bond[1].index]) data.particle_names = atom_names data.particle_elements = elements diff --git a/python-libraries/narupa-openmm/src/narupa/openmm/imd.py b/python-libraries/narupa-openmm/src/narupa/openmm/imd.py index 6b118add..e444cf66 100644 --- a/python-libraries/narupa-openmm/src/narupa/openmm/imd.py +++ b/python-libraries/narupa-openmm/src/narupa/openmm/imd.py @@ -50,10 +50,11 @@ simulation.run(10) """ -from typing import Tuple, Dict, List, Set +from typing import Tuple, Dict, List, Set, Optional import itertools import numpy as np +import numpy.typing as npt import simtk.openmm as mm from simtk.openmm import app @@ -71,6 +72,18 @@ class NarupaImdReporter: + frame_interval: int + force_interval: int + imd_force: mm.CustomExternalForce + imd_state: ImdStateWrapper + frame_publisher: FramePublisher + n_particles: Optional[int] + masses: Optional[npt.NDArray] + positions: Optional[npt.NDArray] + _is_force_dirty: bool + _previous_force_index: Set[int] + _frame_index: int + def __init__( self, frame_interval: int, @@ -186,6 +199,8 @@ def _apply_forces( """ Set the iMD forces based on the user interactions. """ + if self.masses is None: + raise InitialisationError _, forces_kjmol = calculate_imd_force( positions, self.masses, interactions.values(), ) @@ -209,6 +224,16 @@ def _reset_forces(self): self._previous_force_index = set() +class InitialisationError(Exception): + """ + Error raised when the runner has not been initialised correctly and some + attribute have not been set. + + This most likely means that `_on_first_frame` has not been called as + expected. + """ + + def _build_particle_interaction_index_set(interactions: Dict[str, ParticleInteraction]) -> Set[int]: """ Get a set of the indices of the particles involved in interactions. @@ -289,4 +314,4 @@ def get_imd_forces_from_system( if isinstance(force, mm.CustomExternalForce) and force.getEnergyFunction() == IMD_FORCE_EXPRESSION and force.getNumParticles() == system_num_particles - ] \ No newline at end of file + ]