diff --git a/nle/agent/agent.py b/nle/agent/agent.py index 2fcf3b827..748fbe17a 100644 --- a/nle/agent/agent.py +++ b/nle/agent/agent.py @@ -113,10 +113,9 @@ def nested_map(f, n): if isinstance(n, tuple) or isinstance(n, list): return n.__class__(nested_map(f, sn) for sn in n) - elif isinstance(n, dict): + if isinstance(n, dict): return {k: nested_map(f, v) for k, v in n.items()} - else: - return f(n) + return f(n) def compute_baseline_loss(advantages): diff --git a/nle/env/base.py b/nle/env/base.py index 479a87067..534911856 100644 --- a/nle/env/base.py +++ b/nle/env/base.py @@ -485,7 +485,7 @@ def render(self, mode="human"): tty_colors = obs[self._observation_keys.index("tty_colors")] tty_cursor = obs[self._observation_keys.index("tty_cursor")] print(nethack.tty_render(tty_chars, tty_colors, tty_cursor)) - return + return None if mode == "full": message_index = self._observation_keys.index("message") @@ -509,7 +509,7 @@ def render(self, mode="human"): chars = self.last_observation[self._observation_keys.index("chars")] colors = self.last_observation[self._observation_keys.index("colors")] print(nethack.tty_render(chars, colors)) - return + return None if mode in ("ansi", "string"): # Misnomer: This is the least ANSI of them all. chars = self.last_observation[self._observation_keys.index("chars")] diff --git a/nle/nethack/nethack.py b/nle/nethack/nethack.py index 68a0f0220..8a9d87bae 100644 --- a/nle/nethack/nethack.py +++ b/nle/nethack/nethack.py @@ -307,7 +307,7 @@ def set_current_seeds(self, core=None, disp=None, reseed=False): seeds = [core, disp, reseed] if any(s is None for s in seeds): if all(s is None for s in seeds): - return + return None for i, (s, s0) in enumerate(zip(seeds, self.get_current_seeds())): if s is None: seeds[i] = s0 diff --git a/nle/scripts/collect_env.py b/nle/scripts/collect_env.py index a4e42988c..4c6c3bccd 100644 --- a/nle/scripts/collect_env.py +++ b/nle/scripts/collect_env.py @@ -189,14 +189,13 @@ def get_nvidia_smi(): def get_platform(): if sys.platform.startswith("linux"): return "linux" - elif sys.platform.startswith("win32"): + if sys.platform.startswith("win32"): return "win32" - elif sys.platform.startswith("cygwin"): + if sys.platform.startswith("cygwin"): return "cygwin" - elif sys.platform.startswith("darwin"): + if sys.platform.startswith("darwin"): return "darwin" - else: - return sys.platform + return sys.platform def get_mac_version(run_lambda): diff --git a/nle/scripts/read_tty.py b/nle/scripts/read_tty.py index 96bed66d1..62d332bf2 100644 --- a/nle/scripts/read_tty.py +++ b/nle/scripts/read_tty.py @@ -71,16 +71,15 @@ def getfile(filename): f = os.fdopen(os.dup(0), "rb") os.dup2(1, 0) return f - elif os.path.splitext(filename)[1] in (".bz2", ".bzip2"): + if os.path.splitext(filename)[1] in (".bz2", ".bzip2"): import bz2 return bz2.BZ2File(filename) - elif os.path.splitext(filename)[1] in (".gz", ".gzip"): + if os.path.splitext(filename)[1] in (".gz", ".gzip"): import gzip return gzip.GzipFile(filename) - else: - return open(filename, "rb") + return open(filename, "rb") def color(s, value): diff --git a/pyproject.toml b/pyproject.toml index 1014c6b2f..dd33d9172 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,24 +16,6 @@ force_single_line = true profile = "black" skip_glob = "**/__init__.py" -[tool.pylint.messages_control] -disable = [ - "missing-class-docstring", - "invalid-name", # pylint is very strict. - "missing-class-docstring", - "missing-function-docstring", - "missing-module-docstring", - "c-extension-no-member", - "no-member", # too many false positives. -] -[tool.pylint.typecheck] -generated-members=["numpy.*", "torch.*", "nle._pynethack.*"] -[tool.pylint.design] -max-args=15 # Maximum number of arguments for function / method. -max-attributes=50 # Maximum number of attributes for a class (see R0902). -max-bool-expr=5 # Maximum number of boolean expressions in an if statement (see R0916). -max-branches=15 # Maximum number of branch for function / method body. -max-locals=30 # Maximum number of locals for function / method body. [tool.ruff] # See https://docs.astral.sh/ruff/rules/. extend-exclude = [ @@ -55,6 +37,7 @@ select = [ "E", "F", "W", + "R", ] [tool.ruff.lint.flake8-comprehensions] allow-dict-calls-with-keyword-arguments = true