Skip to content

Commit

Permalink
Update code standards
Browse files Browse the repository at this point in the history
  • Loading branch information
GDYendell committed Jul 5, 2024
1 parent ca2574f commit ec7af0d
Show file tree
Hide file tree
Showing 21 changed files with 240 additions and 231 deletions.
77 changes: 49 additions & 28 deletions src/pvi/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from pathlib import Path
from typing import List, Optional
from typing import Annotated, Optional

import typer

Expand All @@ -26,19 +26,25 @@ def version_callback(value: bool):

@app.callback()
def main(
version: Optional[bool] = typer.Option(
# TODO: typer does not support `<type> | None` yet
# https://github.com/tiangolo/typer/issues/533
version: Optional[bool] = typer.Option( # noqa
None,
"--version",
callback=version_callback,
is_eager=True,
help="Print the version and exit",
)
),
):
"""PVI builder interface"""


@app.command()
def schema(output: Path = typer.Argument(..., help="filename to write the schema to")):
def schema(
output: Annotated[
Path, typer.Argument(..., help="filename to write the schema to")
],
):
"""Write the JSON schema for the pvi interface"""
assert output.name.endswith(
".schema.json"
Expand All @@ -57,18 +63,25 @@ def schema(output: Path = typer.Argument(..., help="filename to write the schema

@app.command()
def format(
output_path: Path = typer.Argument(
..., help="Directory to write output file(s) to"
),
device_path: Path = typer.Argument(..., help="Path to the .pvi.device.yaml file"),
formatter_path: Path = typer.Argument(
..., help="Path to the .pvi.formatter.yaml file"
),
yaml_paths: List[Path] = typer.Option(
[], "--yaml-path", help="Paths to directories with .pvi.device.yaml files"
),
output_path: Annotated[
Path, typer.Argument(..., help="Directory to write output file(s) to")
],
device_path: Annotated[
Path, typer.Argument(..., help="Path to the .pvi.device.yaml file")
],
formatter_path: Annotated[
Path, typer.Argument(..., help="Path to the .pvi.formatter.yaml file")
],
yaml_paths: Annotated[
Optional[list[Path]], # noqa
typer.Option(
..., "--yaml-path", help="Paths to directories with .pvi.device.yaml files"
),
] = None,
):
"""Create screen product from device and formatter YAML"""
yaml_paths = yaml_paths or []

device = Device.deserialize(device_path)
device.deserialize_parents(yaml_paths)

Expand All @@ -78,9 +91,11 @@ def format(

@app.command()
def generate_template(
device_path: Path = typer.Argument(..., help="Path to the .pvi.device.yaml file"),
pv_prefix: str = typer.Argument(..., help="Prefix of PVI PV"),
output_path: Path = typer.Argument(..., help="Output file to generate"),
device_path: Annotated[
Path, typer.Argument(..., help="Path to the .pvi.device.yaml file")
],
pv_prefix: Annotated[str, typer.Argument(..., help="Prefix of PVI PV")],
output_path: Annotated[Path, typer.Argument(..., help="Output file to generate")],
):
"""Create template with info tags for device signals"""
device = Device.deserialize(device_path)
Expand All @@ -89,13 +104,18 @@ def generate_template(

@convert_app.command()
def device(
output: Path = typer.Argument(..., help="Directory to write output file to"),
h: Path = typer.Argument(..., help="Path to the .h file to convert"),
templates: List[Path] = typer.Option(
[], "--template", help="Paths to .template files to convert"
),
output: Annotated[
Path, typer.Argument(..., help="Directory to write output file to")
],
h: Annotated[Path, typer.Argument(..., help="Path to the .h file to convert")],
templates: Annotated[
Optional[list[Path]], # noqa
typer.Option(..., "--template", help="Paths to .template files to convert"),
] = None,
):
"""Convert template to device YAML"""
templates = templates or []

if not output.exists():
os.mkdir(output)

Expand All @@ -108,12 +128,13 @@ def device(

@app.command()
def regroup(
device_path: Path = typer.Argument(
..., help="Path to the device.yaml file to regroup"
),
ui_paths: List[Path] = typer.Argument(
..., help="Paths to the ui files to regroup the PVs by"
),
device_path: Annotated[
Path, typer.Argument(..., help="Path to the device.yaml file to regroup")
],
ui_paths: Annotated[
list[Path],
typer.Argument(..., help="Paths to the ui files to regroup the PVs by"),
],
):
"""Regroup a device.yaml file based on ui files that the PVs appear in"""
device = Device.deserialize(device_path)
Expand Down
12 changes: 7 additions & 5 deletions src/pvi/_convert/_asyn_convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, ClassVar, List, Optional, Type, cast
from typing import Any, ClassVar, cast

from pydantic import Field

Expand Down Expand Up @@ -40,7 +40,7 @@ def model_post_init(self, __context: Any):
if "DESC" not in self.fields.keys():
self.fields["DESC"] = self.name

def get_parameter_name(self) -> Optional[str]:
def get_parameter_name(self) -> str | None:
# e.g. from: field(INP, "@asyn($(PORT),$(ADDR=0),$(TIMEOUT=1))FILE_PATH")
# extract: FILE_PATH
parameter_name_extractor = r"@asyn\(.*\)(\S+)"
Expand All @@ -53,7 +53,7 @@ def get_parameter_name(self) -> Optional[str]:
parameter_name = match.group(1)
return parameter_name

def asyn_component_type(self) -> Type["AsynParameter"]:
def asyn_component_type(self) -> type["AsynParameter"]:
# For waveform records the data type is defined by DTYP
if self.type == "waveform":
return get_waveform_parameter(self.fields["DTYP"])
Expand Down Expand Up @@ -236,7 +236,7 @@ class AsynFloat64Waveform(AsynWaveform):


WaveformRecordTypes = [AsynWaveform] + cast(
List[Type[AsynWaveform]], rec_subclasses(AsynWaveform)
list[type[AsynWaveform]], rec_subclasses(AsynWaveform)
)


Expand All @@ -248,4 +248,6 @@ def get_waveform_parameter(dtyp: str):
):
return waveform_cls

assert False, f"Waveform type for DTYP {dtyp} not found in {WaveformRecordTypes}"
raise AssertionError(

Check warning on line 251 in src/pvi/_convert/_asyn_convert.py

View check run for this annotation

Codecov / codecov/patch

src/pvi/_convert/_asyn_convert.py#L251

Added line #L251 was not covered by tests
f"Waveform type for DTYP {dtyp} not found in {WaveformRecordTypes}"
)
9 changes: 4 additions & 5 deletions src/pvi/_convert/_parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
from enum import Enum
from functools import cached_property
from typing import Dict, Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -64,8 +63,8 @@ class DisplayForm(Enum):
class Record(BaseModel):
pv: str # The pv of the record e.g. $(P)$(M)Status
type: str # The record type string e.g. ao, stringin
fields: Dict[str, str] # The record fields
infos: Dict[str, str] # Any infos to be added to the record
fields: dict[str, str] # The record fields
infos: dict[str, str] # Any infos to be added to the record

@cached_property
def name(self) -> str:
Expand All @@ -78,7 +77,7 @@ class Parameter(BaseModel):

invalid: list[str] = ["DESC", "DTYP", "INP", "OUT", "PINI", "VAL"]

def _remove_invalid(self, fields: Dict[str, str]) -> Dict[str, str]:
def _remove_invalid(self, fields: dict[str, str]) -> dict[str, str]:
valid_fields = {
key: value for (key, value) in fields.items() if key not in self.invalid
}
Expand All @@ -89,5 +88,5 @@ def generate_component(self) -> ComponentUnion:


class ReadParameterMixin:
def _get_read_record(self) -> Optional[str]:
def _get_read_record(self) -> str | None:
raise NotImplementedError(self)
35 changes: 17 additions & 18 deletions src/pvi/_convert/_template_convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import re
from pathlib import Path
from typing import List, Tuple

from pvi.device import (
ComponentUnion,
Expand All @@ -23,7 +22,7 @@


class TemplateConverter:
def __init__(self, templates: List[Path]):
def __init__(self, templates: list[Path]):
self.templates = templates
self._text = [t.read_text() for t in self.templates]

Expand All @@ -35,11 +34,11 @@ def convert(self) -> Tree:
children=template_components,
)
for template, template_components in zip(
self.templates, self._extract_components()
self.templates, self._extract_components(), strict=True
)
]

def _extract_components(self) -> List[List[ComponentUnion]]:
def _extract_components(self) -> list[list[ComponentUnion]]:
components = []
for text in self._text:
record_extractor = RecordExtractor(text)
Expand Down Expand Up @@ -70,7 +69,7 @@ def _extract_record_strs(self):
record_extractor = re.compile(r"\s*^[^#\n]*record\([^{]*{[^}]*}", re.MULTILINE)
return re.findall(record_extractor, self._text)

def _parse_record(self, record_str: str) -> Tuple:
def _parse_record(self, record_str: str) -> tuple:
# extract three groups from a record definition e.g.
# from:
# record(waveform, "$(P)$(R)FilePath")
Expand Down Expand Up @@ -100,7 +99,7 @@ def _parse_record(self, record_str: str) -> Tuple:
raise RecordError(f"Parse failed on record: {record_str}")
return matches[0]

def _extract_fields(self, fields_str: str) -> List[Tuple[str, str]]:
def _extract_fields(self, fields_str: str) -> list[tuple[str, str]]:
# extract two groups from a field e.g.
# from: field(PINI, "YES")
# extract:
Expand All @@ -111,7 +110,7 @@ def _extract_fields(self, fields_str: str) -> List[Tuple[str, str]]:
)
return re.findall(field_extractor, fields_str)

def _extract_infos(self, fields_str: str) -> List[Tuple[str, str]]:
def _extract_infos(self, fields_str: str) -> list[tuple[str, str]]:
# extract two groups from an info tag e.g.
# from: info(autosaveFields, "VAL")
# extract:
Expand All @@ -136,7 +135,7 @@ def _create_asyn_record(self, record_str: str) -> AsynRecord:
record = AsynRecord(pv=record_name, type=record_type, fields=fields, infos=info)
return record

def get_asyn_records(self) -> List[AsynRecord]:
def get_asyn_records(self) -> list[AsynRecord]:
record_strs = self._extract_record_strs()
record_list = []
for record_str in record_strs:
Expand All @@ -149,10 +148,10 @@ def get_asyn_records(self) -> List[AsynRecord]:

class RecordRoleSorter:
@staticmethod
def sort_records(records: List[AsynRecord]) -> List[Parameter]:
def sort_records(records: list[AsynRecord]) -> list[Parameter]:
def _sort_inputs_outputs(
records: List[AsynRecord],
) -> Tuple[List[AsynRecord], List[AsynRecord]]:
records: list[AsynRecord],
) -> tuple[list[AsynRecord], list[AsynRecord]]:
inp_records = [r for r in records if "INP" in r.fields]
write_records = [r for r in records if "OUT" in r.fields]

Expand All @@ -169,7 +168,7 @@ def _sort_inputs_outputs(
return read_records, write_records

read_records, write_records = _sort_inputs_outputs(records)
parameters: List[Parameter] = []
parameters: list[Parameter] = []
parameters += ParameterRoleMatcher.get_actions(read_records, write_records)
parameters += ParameterRoleMatcher.get_readbacks(read_records, write_records)
parameters += ParameterRoleMatcher.get_setting_pairs(
Expand Down Expand Up @@ -241,8 +240,8 @@ def generate_component(self) -> SignalW:
class ParameterRoleMatcher:
@staticmethod
def get_actions(
read_records: List[AsynRecord], write_records: List[AsynRecord]
) -> List[Action]:
read_records: list[AsynRecord], write_records: list[AsynRecord]
) -> list[Action]:
actions = [
Action(write_record=w)
for w in write_records
Expand All @@ -253,8 +252,8 @@ def get_actions(

@staticmethod
def get_readbacks(
read_records: List[AsynRecord], write_records: List[AsynRecord]
) -> List[Readback]:
read_records: list[AsynRecord], write_records: list[AsynRecord]
) -> list[Readback]:
readbacks = [
Readback(read_record=r)
for r in read_records
Expand All @@ -265,8 +264,8 @@ def get_readbacks(

@staticmethod
def get_setting_pairs(
read_records: List[AsynRecord], write_records: List[AsynRecord]
) -> List[SettingPair]:
read_records: list[AsynRecord], write_records: list[AsynRecord]
) -> list[SettingPair]:
setting_pairs = [
SettingPair(read_record=r, write_record=w)
for r in read_records
Expand Down
Loading

0 comments on commit ec7af0d

Please sign in to comment.