Skip to content

Commit

Permalink
Merge pull request #11 from IRL2/mypy
Browse files Browse the repository at this point in the history
Run mypy in CI
  • Loading branch information
jbarnoud authored Sep 15, 2023
2 parents d0df6fa + 8a64db7 commit 89c8829
Show file tree
Hide file tree
Showing 25 changed files with 253 additions and 82 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
# This is necessary for the conda action. It replaces `conda init` as
# the shell does not load ,profile or .bashrc.
shell: bash -el {0}
steps:
- uses: actions/checkout@v3
Expand All @@ -24,6 +26,43 @@ jobs:
run: python -m pytest --cov narupa python-libraries -n auto -m 'not serial'
- name: Serial tests
run: python -m pytest --cov narupa python-libraries -n auto -m 'serial'
mypy:
name: Type analysis for python
runs-on: ubuntu-latest
defaults:
run:
# This is necessary for the conda action. It replaces `conda init` as
# the shell does not load ,profile or .bashrc.
shell: bash -el {0}
steps:
- uses: actions/checkout@v3
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
miniforge-version: latest
- name: Install narupa dependancies
run: conda install mpi4py openmm
- name: Install tests dependancies
run: python -m pip install -r python-libraries/requirements.test
- name: Compile
run: ./compile.sh --no-dotnet
- name: mypy
run: |
# Mypy accepts paths, modules or packages as inputs. However, only
# packages work reasonably well with packages. So we need to generate
# a list of the packages we want to test.
packages=$(find python-libraries -name __init__.py \
| sed 's/__init__.py//g' \
| awk '{split($0, a, /src/); print(a[2])}' \
| sed 's#/#.#g' \
| cut -c 2- \
| sed 's/\.$//g' \
| grep -v '^$' \
| grep -v protocol \
| sed 's/^/-p /g' \
| grep -v '\..*\.' \
| tr '\n' ' ')
mypy --ignore-missing-imports --namespace-packages --check-untyped-defs --allow-redefinition $packages
csharp-tests:
name: C# unit and integration tests
runs-on: ubuntu-latest
Expand Down
5 changes: 3 additions & 2 deletions python-libraries/narupa-ase/src/narupa/ase/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ase import Atoms, Atom # type: ignore
import itertools
import numpy as np
import numpy.typing as npt

from narupa.trajectory import FrameData

Expand Down Expand Up @@ -160,7 +161,7 @@ def add_frame_data_positions_to_ase(frame_data, ase_atoms):
ase_atoms.set_positions(np.array(frame_data.particle_positions) * NM_TO_ANG)


def add_ase_positions_to_frame_data(data: FrameData, positions: np.array):
def add_ase_positions_to_frame_data(data: FrameData, positions: npt.NDArray):
"""
Adds ASE positions to the frame data, converting to nanometers.
Expand Down Expand Up @@ -196,7 +197,7 @@ def add_ase_topology_to_frame_data(frame_data: FrameData, ase_atoms: Atoms, gene
frame_data.residue_names = ["ASE"]
frame_data.residue_chains = [0]
frame_data.residue_count = 1
frame_data.residue_ids =['1']
frame_data.residue_ids = ['1']

frame_data.chain_names = ["A"]
frame_data.chain_count = 1
Expand Down
5 changes: 3 additions & 2 deletions python-libraries/narupa-ase/src/narupa/ase/imd.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ def _register_commands(self):
self._server.register_command(RESET_COMMAND_KEY, self.reset)
self._server.register_command(STEP_COMMAND_KEY, self.step)
self._server.register_command(PAUSE_COMMAND_KEY, self.pause)
self._server.register_command(SET_DYNAMICS_INTERVAL_COMMAND_KEY, self._set_dynamics_interval)
self._server.register_command(GET_DYNAMICS_INTERVAL_COMMAND_KEY, self._get_dynamics_interval)
# TODO: Fix type annotations for commands with arguments
self._server.register_command(SET_DYNAMICS_INTERVAL_COMMAND_KEY, self._set_dynamics_interval) # type: ignore
self._server.register_command(GET_DYNAMICS_INTERVAL_COMMAND_KEY, self._get_dynamics_interval) # type: ignore

def close(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions python-libraries/narupa-ase/src/narupa/ase/imd_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class ImdCalculator(Calculator):
"""

_previous_interactions: Dict[str, ParticleInteraction]

def __init__(self, imd_state: ImdStateWrapper,
calculator: Optional[Calculator] = None,
atoms: Optional[Atoms] = None,
Expand Down
6 changes: 5 additions & 1 deletion python-libraries/narupa-ase/src/narupa/ase/openmm/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,14 @@ def from_xml(cls, simulation_xml,
:param logging_params: The :class:LoggingParams to set up trajectory logging with.
:return: An OpenMM simulation runner.
"""
if params is not None:
platform = params.platform
else:
platform = None
with open(simulation_xml) as infile:
simulation = serializer.deserialize_simulation(
infile.read(),
platform_name=params.platform,
platform_name=platform,
)
return cls(simulation, params, logging_params)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import datetime
import os
from io import TextIOWrapper
from typing import Optional

import ase.io
Expand Down Expand Up @@ -83,6 +84,11 @@ class TrajectoryLogger:
"""
format: str
frame_index: int
atoms: Atoms
base_path: str
parallel: bool
_file_descriptor: Optional[TextIOWrapper]

def __init__(self, atoms: Atoms, filename: str, format: Optional[str] = None, timestamp=True, parallel=True,
**kwargs):
Expand Down
6 changes: 3 additions & 3 deletions python-libraries/narupa-core/src/narupa/app/app_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def close(self):
Close the application server and all services.
"""
if self.running_discovery:
self._discovery.close()
self._discovery.close() # type: ignore
for service in self._services:
service.close()
self._server.close()
Expand All @@ -192,10 +192,10 @@ def _add_service_entry(self, name: str, port: int):

def _update_discovery_services(self):
try:
self._discovery.unregister_service(self._service_hub)
self._discovery.unregister_service(self._service_hub) # type: ignore
except KeyError:
pass
self._discovery.register_service(self._service_hub)
self._discovery.register_service(self._service_hub) # type: ignore


def qualified_server_name(base_name: str):
Expand Down
25 changes: 13 additions & 12 deletions python-libraries/narupa-core/src/narupa/app/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from collections import deque, ChainMap
from functools import wraps, partial
from typing import Iterable, Tuple, Type, TypeVar
from typing import Iterable, Tuple, Type, TypeVar, cast
from typing import Optional, Sequence, Dict, MutableMapping
from uuid import uuid4

Expand Down Expand Up @@ -162,9 +162,9 @@ class NarupaImdClient:
_subscribed_to_all_frames: bool

def __init__(self, *,
trajectory_address: Tuple[str, int] = None,
imd_address: Tuple[str, int] = None,
multiplayer_address: Tuple[str, int] = None,
trajectory_address: Optional[Tuple[str, int]] = None,
imd_address: Optional[Tuple[str, int]] = None,
multiplayer_address: Optional[Tuple[str, int]] = None,
max_frames=50,
all_frames: Optional[bool] = None,
):
Expand Down Expand Up @@ -320,9 +320,9 @@ def connect_multiplayer(self, address: Tuple[str, int]):
self._multiplayer_client = self._connect_client(NarupaClient, address)

def connect(self, *,
trajectory_address: Tuple[str, int] = None,
imd_address: Tuple[str, int] = None,
multiplayer_address: Tuple[str, int] = None,
trajectory_address: Optional[Tuple[str, int]] = None,
imd_address: Optional[Tuple[str, int]] = None,
multiplayer_address: Optional[Tuple[str, int]] = None,
):
"""
Connects the client to all services for which addresses are provided.
Expand Down Expand Up @@ -582,7 +582,8 @@ def subscribe_multiplayer(self, interval=DEFAULT_STATE_UPDATE_INTERVAL):
:raises grpc._channel._Rendezvous: When not connected to a
multiplayer service
"""
self._multiplayer_client.subscribe_all_state_updates(interval)
multiplayer_client = cast(NarupaClient, self._multiplayer_client)
multiplayer_client.subscribe_all_state_updates(interval)

@need_multiplayer
def attempt_update_multiplayer_state(
Expand Down Expand Up @@ -796,8 +797,8 @@ def subscribe_to_all_frames(self):
if self._are_framed_subscribed:
return
self._subscribed_to_all_frames = True
self._frame_client: FrameClient # @need_frames makes sure of that
self._frame_client.subscribe_frames_async(self._on_frame_received)
frame_client = cast(FrameClient, self._frame_client) # @need_frames makes sure of that
frame_client.subscribe_frames_async(self._on_frame_received)
self._are_framed_subscribed = True

@need_frames
Expand Down Expand Up @@ -831,8 +832,8 @@ def subscribe_to_frames(self, interval: float = DEFAULT_SUBSCRIPTION_INTERVAL):
if self._are_framed_subscribed:
return
self._subscribed_to_all_frames = False
self._frame_client: FrameClient # @need_frames makes sure of that
self._frame_client.subscribe_last_frames_async(
frame_client = cast(FrameClient, self._frame_client) # @need_frames makes sure of that
frame_client.subscribe_last_frames_async(
self._on_frame_received,
DEFAULT_SUBSCRIPTION_INTERVAL,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Module providing an implementation of the :class:`CommandServicer`.
"""
from typing import Dict, Callable, Optional
from typing import Dict, Callable, Optional, Union
from typing import NamedTuple

import grpc
Expand All @@ -22,7 +22,10 @@
from narupa.utilities.key_lockable_map import KeyLockableMap
from narupa.utilities.protobuf_utilities import dict_to_struct, struct_to_dict

CommandHandler = Callable[[CommandArguments], Optional[CommandResult]]
CommandHandler = Union[
Callable[[CommandArguments], Optional[CommandResult]],
Callable[[], Optional[CommandResult]],
]


class CommandRegistration(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions python-libraries/narupa-core/src/narupa/core/narupa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Copyright (c) Intangible Realities Lab, University Of Bristol. All rights reserved.
# Licensed under the GPL. See License.txt in the project root for license information.

from typing import Dict, Iterable, ContextManager, Union, Any
from typing import Dict, Iterable, ContextManager, Union, Any, Mapping
from uuid import uuid4

import grpc
Expand Down Expand Up @@ -139,7 +139,7 @@ def attempt_update_state(self, change: DictionaryChange) -> bool:

def attempt_update_locks(
self,
lock_updates: Dict[str, Union[float, None]]
lock_updates: Mapping[str, Union[float, None]]
) -> bool:
"""
Attempt to acquire and/or free a number of locks on the shared state.
Expand Down
12 changes: 9 additions & 3 deletions python-libraries/narupa-core/src/narupa/core/narupa_server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
# Copyright (c) Intangible Realities Lab, University Of Bristol. All rights reserved.
# Licensed under the GPL. See License.txt in the project root for license information.
from typing import Callable, Optional, Dict, ContextManager, Set
from typing import Callable, Optional, Dict, ContextManager, Set, Union

from narupa.command import CommandService
from narupa.command.command_service import CommandRegistration
from narupa.command.command_service import CommandRegistration, CommandHandler
from narupa.core import GrpcServer
from narupa.state.state_service import StateService
from narupa.utilities.change_buffers import (
DictionaryChange,
DictionaryChangeBuffer,
)

CommandCallable = Union[
Callable[[Dict], Optional[Dict]],
Callable[[], None],
Callable[[Dict], None]
]


class NarupaServer(GrpcServer):
"""
Expand Down Expand Up @@ -42,7 +48,7 @@ def commands(self) -> Dict[str, CommandRegistration]:
"""
return self._command_service.commands

def register_command(self, name: str, callback: Callable[[Dict], Optional[Dict]],
def register_command(self, name: str, callback: CommandHandler,
default_arguments: Optional[Dict] = None):
"""
Registers a command with the :class:`CommandService` running on this server.
Expand Down
6 changes: 5 additions & 1 deletion python-libraries/narupa-core/src/narupa/imd/imd_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the GPL. See License.txt in the project root for license information.
import logging
from uuid import uuid4
from typing import Dict, Set
from typing import Dict, Set, Mapping

import grpc
from narupa.core import NarupaClient
Expand Down Expand Up @@ -37,6 +37,10 @@ def interactions(self) -> Dict[str, ParticleInteraction]:
key: dict_to_interaction(value)
for key, value in state.items()
if key.startswith(INTERACTION_PREFIX)
# We can have a misformatted interactions in the shared state.
# Here we ignore them silently.
# TODO: log a warning when this happens.
and isinstance(value, Mapping)
}

def start_interaction(self) -> str:
Expand Down
18 changes: 10 additions & 8 deletions python-libraries/narupa-core/src/narupa/imd/imd_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
from typing import Collection, Tuple, Optional, Iterable

import numpy as np
import numpy.typing as npt
from narupa.imd.particle_interaction import ParticleInteraction


def calculate_imd_force(
positions: np.ndarray,
masses: np.ndarray,
positions: npt.NDArray,
masses: npt.NDArray,
interactions: Iterable[ParticleInteraction],
periodic_box_lengths: Optional[np.ndarray] = None,
) -> Tuple[float, np.array]:
periodic_box_lengths: Optional[npt.NDArray] = None,
) -> Tuple[float, npt.NDArray]:
"""
Reference implementation of the Narupa IMD force.
Expand All @@ -46,8 +47,8 @@ def calculate_imd_force(
return total_energy, forces


def apply_single_interaction_force(positions: np.ndarray, masses: np.ndarray, interaction, forces: np.ndarray,
periodic_box_lengths: Optional[np.array] = None) -> float:
def apply_single_interaction_force(positions: npt.NDArray, masses: npt.NDArray, interaction, forces: npt.NDArray,
periodic_box_lengths: Optional[npt.NDArray] = None) -> float:
"""
Calculates the energy and adds the forces to the particles of a single application of an interaction potential.
Expand Down Expand Up @@ -154,6 +155,7 @@ def get_center_of_mass_subset(
before calculating centre of mass.
:return: The center of mass of the subset of positions.
"""
subset = list(subset)
subset_positions = positions[subset]
subset_masses = masses[subset, np.newaxis]
subset_total_mass = subset_masses.sum()
Expand Down Expand Up @@ -194,8 +196,8 @@ def calculate_gaussian_force(particle_position: np.ndarray, interaction_position
return energy, force


def calculate_spring_force(particle_position: np.array, interaction_position: np.array, k=1,
periodic_box_lengths: Optional[np.ndarray] = None) -> Tuple[float, np.array]:
def calculate_spring_force(particle_position: npt.NDArray, interaction_position: npt.NDArray, k=1,
periodic_box_lengths: Optional[npt.NDArray] = None) -> Tuple[float, npt.NDArray]:
"""
Computes the interactive harmonic potential (or spring) force.
Expand Down
4 changes: 2 additions & 2 deletions python-libraries/narupa-core/src/narupa/imd/imd_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
Module providing methods for storing ParticleInteractions in a StateDictionary.
"""
from typing import Dict, Any
from typing import Dict, Any, Mapping

from narupa.state.state_dictionary import StateDictionary
from narupa.utilities.change_buffers import DictionaryChange
Expand Down Expand Up @@ -106,7 +106,7 @@ def interaction_to_dict(interaction: ParticleInteraction) -> Dict[str, Serializa
raise TypeError from e


def dict_to_interaction(dictionary: Dict[str, Any]) -> ParticleInteraction:
def dict_to_interaction(dictionary: Mapping[str, Any]) -> ParticleInteraction:
kwargs = dict(**dictionary)
if 'particles' in kwargs:
kwargs['particles'] = [int(i) for i in kwargs['particles']]
Expand Down
Loading

0 comments on commit 89c8829

Please sign in to comment.