Skip to content

Commit

Permalink
Add more rules to ruff linting
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Sep 10, 2024
1 parent f02295d commit fc03b62
Show file tree
Hide file tree
Showing 20 changed files with 309 additions and 297 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/style.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: fmu-ensemble
name: style

on:
push:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
rev: v0.6.4
hooks:
- id: ruff
args: [ --extend-select, I, --fix ]
args: [ --fix ]
- id: ruff-format

exclude: "tests/data/testensemble-reek001"
39 changes: 39 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,42 @@ write_to = "src/fmu/ensemble/version.py"
ignore_directives = ["argparse", "automodule"]
# This looks like a bug in rstcheck:
ignore_messages = "Hyperlink target .* is not referenced"

[tool.ruff]
src = ["src"]
line-length = 88

[tool.ruff.lint]
select = [
"W", # pycodestyle warnings
"I", # isort
"B", # flake8-bugbear
"SIM", # flake8-simplify
"F", # pyflakes
"PL", # pylint
"NPY", # numpy specific rules
"C4", # flake8-comprehensions
]
ignore = ["PLW2901", # redefined-loop-name
"PLR2004", # magic-value-comparison
"PLR0915", # too-many-statements
"PLR0912", # too-many-branches
"PLR0911", # too-many-return-statements
"PLC2701", # import-private-name
"PLR6201", # literal-membership
"PLR0914", # too-many-locals
"PLR6301", # no-self-use
"PLW1641", # eq-without-hash
"PLR0904", # too-many-public-methods
"PLR1702", # too-many-nested-blocks
"PLW3201", # bad-dunder-method-name
"B028", # no-explicit-stacklevel
]

[tool.ruff.lint.extend-per-file-ignores]
"tests/*" = [
"PLW0603" # global-statement
]

[tool.ruff.lint.pylint]
max-args = 20
25 changes: 11 additions & 14 deletions src/fmu/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,12 @@ def __init__(
globbedpaths = [glob.glob(path) for path in paths]
globbedpaths = list({item for sublist in globbedpaths for item in sublist})
if not globbedpaths:
if isinstance(runpathfile, str):
if not runpathfile:
logger.warning("Initialized empty ScratchEnsemble")
return
if isinstance(runpathfile, pd.DataFrame):
if runpathfile.empty:
logger.warning("Initialized empty ScratchEnsemble")
return
if isinstance(runpathfile, str) and not runpathfile:
logger.warning("Initialized empty ScratchEnsemble")
return
if isinstance(runpathfile, pd.DataFrame) and runpathfile.empty:
logger.warning("Initialized empty ScratchEnsemble")
return

count = None
if globbedpaths:
Expand Down Expand Up @@ -893,9 +891,8 @@ def filter(self, localpath, inplace=True, **kwargs):
if inplace:
if not realization.contains(localpath, **kwargs):
deletethese.append(realidx)
else:
if realization.contains(localpath, **kwargs):
keepthese.append(realidx)
elif realization.contains(localpath, **kwargs):
keepthese.append(realidx)

if inplace:
logger.info("Removing realizations %s", deletethese)
Expand Down Expand Up @@ -932,7 +929,7 @@ def drop(self, localpath, **kwargs):
if shortcut2path(self.keys(), localpath) not in self.keys():
raise ValueError("%s not found" % localpath)
for _, realization in self.realizations.items():
try:
try: # noqa: SIM105
realization.drop(localpath, **kwargs)
except ValueError:
pass # Allow localpath to be missing in some realizations
Expand Down Expand Up @@ -1176,7 +1173,7 @@ def get_wellnames(self, well_match=None):
for well in well_match:
result = result.union(set(eclsum.wells(well)))

return sorted(list(result))
return sorted(result)

def get_groupnames(self, group_match=None):
"""
Expand Down Expand Up @@ -1213,7 +1210,7 @@ def get_groupnames(self, group_match=None):
for group in group_match:
result = result.union(set(eclsum.groups(group)))

return sorted(list(result))
return sorted(result)

def agg(self, aggregation, keylist=None, excludekeys=None):
"""Aggregate the ensemble data into one VirtualRealization
Expand Down
4 changes: 2 additions & 2 deletions src/fmu/ensemble/ensembleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def drop(self, localpath, **kwargs):
if self.shortcut2path(localpath) not in self.keys():
raise ValueError("%s not found" % localpath)
for _, ensemble in self._ensembles.items():
try:
try: # noqa: SIM105
ensemble.drop(localpath, **kwargs)
except ValueError:
pass # Allow localpath to be missing in some ensembles.
Expand Down Expand Up @@ -781,4 +781,4 @@ def get_wellnames(self, well_match=None):
result = set()
for _, ensemble in self._ensembles.items():
result = result.union(ensemble.get_wellnames(well_match))
return sorted(list(result))
return sorted(result)
149 changes: 76 additions & 73 deletions src/fmu/ensemble/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, observations):
observations: dict with observation structure or string
with path to a yaml file.
"""
self.observations = dict()
self.observations = {}

if isinstance(observations, str):
with open(observations) as yamlfile:
Expand Down Expand Up @@ -183,7 +183,7 @@ def load_smry(self, realization, smryvector, time_index="yearly", smryerror=None
# it is ok (assuming ISO-datestrings)

# Modify the observation object (self)
if "smry" not in self.observations.keys():
if "smry" not in self.observations:
self.observations["smry"] = [] # Empty list

# Construct a virtual observation with observation units
Expand Down Expand Up @@ -251,7 +251,7 @@ def _realization_mismatch(self, real):
# mismatch_df = pd.DataFrame(columns=['OBSTYPE', 'OBSKEY',
# 'DATE', 'OBSINDEX', 'MISMATCH', 'L1', 'L2', 'SIGN'])
mismatches = []
for obstype in self.observations.keys():
for obstype in self.observations:
for obsunit in self.observations[obstype]: # (list)
if obstype == "txt":
try:
Expand All @@ -267,20 +267,20 @@ def _realization_mismatch(self, real):
measerror = 1
sign = (mismatch > 0) - (mismatch < 0)
mismatches.append(
dict(
OBSTYPE=obstype,
OBSKEY=str(obsunit["localpath"])
{
"OBSTYPE": obstype,
"OBSKEY": str(obsunit["localpath"])
+ "/"
+ str(obsunit["key"]),
LABEL=obsunit.get("label", ""),
MISMATCH=mismatch,
L1=abs(mismatch),
L2=abs(mismatch) ** 2,
SIMVALUE=sim_value,
OBSVALUE=obsunit["value"],
MEASERROR=measerror,
SIGN=sign,
)
"LABEL": obsunit.get("label", ""),
"MISMATCH": mismatch,
"L1": abs(mismatch),
"L2": abs(mismatch) ** 2,
"SIMVALUE": sim_value,
"OBSVALUE": obsunit["value"],
"MEASERROR": measerror,
"SIGN": sign,
}
)
if obstype == "scalar":
try:
Expand All @@ -294,18 +294,18 @@ def _realization_mismatch(self, real):
measerror = 1
sign = (mismatch > 0) - (mismatch < 0)
mismatches.append(
dict(
OBSTYPE=obstype,
OBSKEY=str(obsunit["key"]),
LABEL=obsunit.get("label", ""),
MISMATCH=mismatch,
L1=abs(mismatch),
SIMVALUE=sim_value,
OBSVALUE=obsunit["value"],
MEASERROR=measerror,
L2=abs(mismatch) ** 2,
SIGN=sign,
)
{
"OBSTYPE": obstype,
"OBSKEY": str(obsunit["key"]),
"LABEL": obsunit.get("label", ""),
"MISMATCH": mismatch,
"L1": abs(mismatch),
"SIMVALUE": sim_value,
"OBSVALUE": obsunit["value"],
"MEASERROR": measerror,
"L2": abs(mismatch) ** 2,
"SIGN": sign,
}
)
if obstype == "smryh":
if "time_index" in obsunit:
Expand Down Expand Up @@ -352,16 +352,16 @@ def _realization_mismatch(self, real):
)
measerror = 1
mismatches.append(
dict(
OBSTYPE="smryh",
OBSKEY=obsunit["key"],
LABEL=obsunit.get("label", ""),
MISMATCH=sim_hist["mismatch"].sum(),
MEASERROR=measerror,
L1=sim_hist["mismatch"].abs().sum(),
L2=math.sqrt((sim_hist["mismatch"] ** 2).sum()),
TIME_INDEX=time_index_str,
)
{
"OBSTYPE": "smryh",
"OBSKEY": obsunit["key"],
"LABEL": obsunit.get("label", ""),
"MISMATCH": sim_hist["mismatch"].sum(),
"MEASERROR": measerror,
"L1": sim_hist["mismatch"].abs().sum(),
"L2": math.sqrt((sim_hist["mismatch"] ** 2).sum()),
"TIME_INDEX": time_index_str,
}
)
if obstype == "smry":
# For 'smry', there is a list of
Expand All @@ -381,19 +381,19 @@ def _realization_mismatch(self, real):
mismatch = float(sim_value - unit["value"])
sign = (mismatch > 0) - (mismatch < 0)
mismatches.append(
dict(
OBSTYPE="smry",
OBSKEY=obsunit["key"],
DATE=unit["date"],
MEASERROR=unit["error"],
LABEL=unit.get("label", ""),
MISMATCH=mismatch,
OBSVALUE=unit["value"],
SIMVALUE=sim_value,
L1=abs(mismatch),
L2=abs(mismatch) ** 2,
SIGN=sign,
)
{
"OBSTYPE": "smry",
"OBSKEY": obsunit["key"],
"DATE": unit["date"],
"MEASERROR": unit["error"],
"LABEL": unit.get("label", ""),
"MISMATCH": mismatch,
"OBSVALUE": unit["value"],
"SIMVALUE": sim_value,
"L1": abs(mismatch),
"L2": abs(mismatch) ** 2,
"SIGN": sign,
}
)
return pd.DataFrame(mismatches)

Expand Down Expand Up @@ -422,13 +422,12 @@ def _realization_misfit(self, real, defaulterrors=False, corr=None):
zeroerrors = mismatch["MEASERROR"] < 1e-7
if defaulterrors:
mismatch[zeroerrors]["MEASERROR"] = 1
else:
if zeroerrors.any():
print(mismatch[zeroerrors])
raise ValueError(
"Zero measurement error in observation set"
+ ". can't be used to calculate misfit"
)
elif zeroerrors.any():
print(mismatch[zeroerrors])
raise ValueError(
"Zero measurement error in observation set"
+ ". can't be used to calculate misfit"
)
if "MISFIT" not in mismatch.columns:
mismatch["MISFIT"] = mismatch["L2"] / (mismatch["MEASERROR"] ** 2)

Expand Down Expand Up @@ -460,7 +459,7 @@ def _clean_observations(self):
)
self.observations.pop(key)
# Check smryh observations for validity
if "smryh" in self.observations.keys():
if "smryh" in self.observations:
smryhunits = self.observations["smryh"]
if not isinstance(smryhunits, list):
logger.warning(
Expand All @@ -484,32 +483,36 @@ def _clean_observations(self):
continue
# If time_index is not a supported mnemonic,
# parse it to a date object
if "time_index" in unit:
if unit["time_index"] not in [
if (
"time_index" in unit
and unit["time_index"]
not in {
"raw",
"report",
"yearly",
"daily",
"first",
"last",
"monthly",
] and not isinstance(unit["time_index"], datetime.datetime):
try:
unit["time_index"] = dateutil.parser.isoparse(
unit["time_index"]
).date()
except (TypeError, ValueError) as exception:
logger.warning(
"Parsing date %s failed with error",
(str(unit["time_index"]), str(exception)),
)
del smryhunits[smryhunits.index(unit)]
continue
}
and not isinstance(unit["time_index"], datetime.datetime)
):
try:
unit["time_index"] = dateutil.parser.isoparse(
unit["time_index"]
).date()
except (TypeError, ValueError) as exception:
logger.warning(
"Parsing date %s failed with error",
(str(unit["time_index"]), str(exception)),
)
del smryhunits[smryhunits.index(unit)]
continue
# If everything has been deleted through cleanup, delete the section
if not smryhunits:
del self.observations["smryh"]
# Check smry observations for validity
if "smry" in self.observations.keys():
if "smry" in self.observations:
# We already know that observations['smry'] is a list
# Each list element must be a dict with
# the mandatory keys 'key' and 'observation'
Expand Down
Loading

0 comments on commit fc03b62

Please sign in to comment.