Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 Replace black/isort/pyupgrade/flake8 with ruff #101

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,16 @@ exclude: >

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.5.0
hooks:
- id: check-json
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.31.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.6
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
- id: ruff
args: [--fix]
- id: ruff-format
12 changes: 3 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ pytest --lammps-workdir "test_workdir"

### Pre-commit

The code is formatted and linted using [pre-commit](https://pre-commit.com/), so that the code conforms to the standard. One must simply install the repository with the `pre-commit` extra dependencies:
The code is formatted and linted using [pre-commit](https://pre-commit.com/), so that the code conforms to the standard:

```shell
cd aiida-lammps
pip install -e .[pre-commit]
pre-commit run --all
```
or to automate runs, triggered before each commit:
Expand All @@ -120,11 +119,6 @@ or to automate runs, triggered before each commit:
pre-commit install
```

The pre-commit can also be run in an isolated environment via `tox` with:

```shell
tox -e pre-commit
```

## License
The `aiida-lammps` plugin package is released under the MIT license. See the `LICENSE.txt` file for more details.

The `aiida-lammps` plugin package is released under the MIT license. See the `LICENSE` file for more details.
9 changes: 3 additions & 6 deletions aiida_lammps/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
the input structure and whether or not a restart file is provided.
"""
import os
from typing import Union
from typing import ClassVar, Dict, Union

from aiida import orm
from aiida.common import datastructures
Expand All @@ -30,7 +30,7 @@ class LammpsBaseCalculation(CalcJob):
the input structure and whether or not a restart file is provided.
"""

_DEFAULT_VARIABLES = {
_DEFAULT_VARIABLES: ClassVar[Dict[str, str]] = {
"input_filename": "input.in",
"structure_filename": "structure.dat",
"output_filename": "lammps.out",
Expand Down Expand Up @@ -325,10 +325,7 @@ def prepare_for_submission(self, folder):
else:
_parameters = {}

if "settings" in self.inputs:
settings = self.inputs.settings.get_dict()
else:
settings = {}
settings = self.inputs.settings.get_dict() if "settings" in self.inputs else {}

# Set the remote copy list and the symlink so that if one needs to use restartfiles from
# a previous calculations one can do so without problems
Expand Down
5 changes: 2 additions & 3 deletions aiida_lammps/calculations/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,8 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
# namespace, falling back to the filename of the ``SinglefileData`` node if not defined.
filename = filenames.get(key, node.filename)

with folder.open(filename, "wb") as target:
with node.open(mode="rb") as source:
shutil.copyfileobj(source, target)
with folder.open(filename, "wb") as target, node.open(mode="rb") as source:
shutil.copyfileobj(source, target)

provenance_exclude_list.append(filename)

Expand Down
32 changes: 16 additions & 16 deletions aiida_lammps/data/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class written by Sebastiaan Huber.
import json
import os
import pathlib
from typing import BinaryIO, List, Optional, Union
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Union
import warnings

from aiida import orm, plugins
Expand Down Expand Up @@ -164,7 +164,7 @@ class written by Sebaastian Huber.
"lammps_potentials.json",
)

_extra_keys = {
_extra_keys: ClassVar[Dict[str, Any]] = {
"title": {"validator": _validate_string},
"developer": {"validator": _validate_string_list},
"publication_year": {"validator": _validate_datetime},
Expand All @@ -190,12 +190,12 @@ class written by Sebaastian Huber.
def get_or_create(
cls,
source: Union[str, pathlib.Path, BinaryIO],
filename: str = None,
pair_style: str = None,
species: list = None,
atom_style: str = None,
units: str = None,
extra_tags: dict = None,
filename: Optional[str] = None,
pair_style: Optional[str] = None,
species: Optional[list] = None,
atom_style: Optional[str] = None,
units: Optional[str] = None,
extra_tags: Optional[dict] = None,
):
"""
Get lammps potential data node from database or create a new one.
Expand Down Expand Up @@ -290,7 +290,7 @@ def prepare_source(cls, source: Union[str, pathlib.Path, BinaryIO]) -> BinaryIO:
) and not cls.is_readable_byte_stream(source):
raise TypeError(
"`source` should be a `str` or `pathlib.Path` filepath on "
+ f"disk or a stream of bytes, got: {source}"
f"disk or a stream of bytes, got: {source}"
)

if isinstance(source, (str, pathlib.Path)):
Expand Down Expand Up @@ -328,7 +328,7 @@ def validate_pair_style(self, pair_style: str):
"""
if pair_style is None:
raise TypeError("The pair_style of the potential must be provided.")
if pair_style not in self.default_potential_info.keys():
if pair_style not in self.default_potential_info:
raise KeyError(f'The pair_style "{pair_style}" is not valid')
self.base.attributes.set("pair_style", pair_style)

Expand Down Expand Up @@ -440,12 +440,12 @@ def validate_extra_tags(self, extra_tags: dict):
def set_file(
self,
source: Union[str, pathlib.Path, BinaryIO],
filename: str = None,
pair_style: str = None,
species: list = None,
atom_style: str = None,
units: str = None,
extra_tags: dict = None,
filename: Optional[str] = None,
pair_style: Optional[str] = None,
species: Optional[list] = None,
atom_style: Optional[str] = None,
units: Optional[str] = None,
extra_tags: Optional[dict] = None,
**kwargs,
):
"""Set the file content.
Expand Down
40 changes: 18 additions & 22 deletions aiida_lammps/data/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def set_from_fileobj(self, fileobj, aliases=None):
self._compression_method,
) as zip_file:
for step_id, trajectory_step in enumerate(iter_trajectories(fileobj)):

# extract data to store in attributes
time_steps.append(trajectory_step.timestep)
if number_atoms is None:
Expand Down Expand Up @@ -135,13 +134,13 @@ def set_from_fileobj(self, fileobj, aliases=None):

self.base.attributes.set("number_steps", len(time_steps))
self.base.attributes.set("number_atoms", number_atoms)
self.base.attributes.set("field_names", list(sorted(field_names)))
self.base.attributes.set("field_names", sorted(field_names))
self.base.attributes.set("trajectory_filename", self._trajectory_filename)
self.base.attributes.set("timestep_filename", self._timestep_filename)
self.base.attributes.set("zip_prefix", self._zip_prefix)
self.base.attributes.set("compression_method", self._compression_method)
self.base.attributes.set("aliases", aliases)
self.base.attributes.set("elements", list(sorted(elements)))
self.base.attributes.set("elements", sorted(elements))

@property
def number_steps(self):
Expand Down Expand Up @@ -199,14 +198,12 @@ def get_step_string(self, step_idx):
with self.base.repository.open(
self.base.attributes.get("trajectory_filename"),
mode="rb",
) as handle:
with ZipFile(
handle,
"r",
self.base.attributes.get("compression_method"),
) as zip_file:
with zip_file.open(zip_name, "r") as step_file:
content = step_file.read()
) as handle, ZipFile(
handle,
"r",
self.base.attributes.get("compression_method"),
) as zip_file, zip_file.open(zip_name, "r") as step_file:
content = step_file.read()
return content.decode("utf8")

def get_step_data(self, step_idx):
Expand All @@ -224,17 +221,16 @@ def iter_step_strings(self, steps=None):
with self.base.repository.open(
self.base.attributes.get("trajectory_filename"),
mode="rb",
) as handle:
with ZipFile(
handle,
"r",
self.base.attributes.get("compression_method"),
) as zip_file:
for step_idx in steps:
zip_name = f'{self.base.attributes.get("zip_prefix")}{step_idx}'
with zip_file.open(zip_name) as step_file:
content = step_file.read()
yield content
) as handle, ZipFile(
handle,
"r",
self.base.attributes.get("compression_method"),
) as zip_file:
for step_idx in steps:
zip_name = f'{self.base.attributes.get("zip_prefix")}{step_idx}'
with zip_file.open(zip_name) as step_file:
content = step_file.read()
yield content

def get_step_structure(
self,
Expand Down
22 changes: 12 additions & 10 deletions aiida_lammps/parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,18 @@ def parse_restartfile(

restart_filename = ""

if parameters.get("restart", {}).get("print_final", False):
if input_restart_filename in list_of_files:
with self.node.outputs.retrieved.base.repository.open(
input_restart_filename,
mode="rb",
) as handle:
restart_file = orm.SinglefileData(handle)
self.out("restartfile", restart_file)
restart_found = True
restart_filename = input_restart_filename
if (
parameters.get("restart", {}).get("print_final", False)
and input_restart_filename in list_of_files
):
with self.node.outputs.retrieved.base.repository.open(
input_restart_filename,
mode="rb",
) as handle:
restart_file = orm.SinglefileData(handle)
self.out("restartfile", restart_file)
restart_found = True
restart_filename = input_restart_filename

if (
parameters.get("restart", {}).get("print_intermediate", False)
Expand Down
Loading