Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Wizkit support #78

Merged
merged 6 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ venv.bak/
.spyderproject
.spyproject

# IDE
.idea/

# Rope project settings
.ropeproject

Expand Down
10 changes: 7 additions & 3 deletions nle/env/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ def _get_observation(self, observation):
for key, i in zip(self._original_observation_keys, self._original_indices)
}

def print_action_meanings(self):
for a_idx, a in enumerate(self._actions):
print(a_idx, a)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have something similar when you play via python nle/scripts/play.py, but if you find this helpful, why not.

Generally we try to use _ for pseudo-private methods, so this might be better w/o the underscore?

The same goes for tests btw -- we are not 100% strict about this, but I'd prefer there to be no access to underscored variables/functions in tests, as this breaks encapsulation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the underscore from the method's name.

How do you write unit tests for some internal behaviour of an object without accessing internal fields?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

That's a great question and as I said we're not 100% strict about that. But my idea would be that unit tests should test public behavior, not their private implementation.

More on this idea here: https://testing.googleblog.com/2015/01/testing-on-toilet-prefer-testing-public.html?m=1


def step(self, action: int):
"""Steps the environment.

Expand Down Expand Up @@ -348,7 +352,7 @@ def _in_moveloop(self, observation):
program_state = observation[self._program_state_index]
return program_state[3] # in_moveloop

def reset(self):
def reset(self, wizkit_items=None):
"""Resets the environment.

Note:
Expand All @@ -362,7 +366,7 @@ def reset(self):
"""
self._episode += 1
new_ttyrec = self._ttyrec_pattern % self._episode if self.savedir else None
self.last_observation = self.env.reset(new_ttyrec)
self.last_observation = self.env.reset(new_ttyrec, wizkit_items=wizkit_items)

# Only run on the first reset to initialize stats file
if self._setup_statsfile:
Expand Down Expand Up @@ -396,7 +400,7 @@ def reset(self):
warnings.warn(
"Not in moveloop after 1000 tries, aborting (ttyrec: %s)." % new_ttyrec
)
return self.reset()
return self.reset(wizkit_items=wizkit_items)

return self._get_observation(self.last_observation)

Expand Down
22 changes: 19 additions & 3 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@
]

HACKDIR = os.getenv("HACKDIR", pkg_resources.resource_filename("nle", "nethackdir"))
WIZKIT_FNAME = "wizkit.txt"


def _set_env_vars(options, hackdir):
def _set_env_vars(options, hackdir, wizkit=None):
# TODO: Investigate not using environment variables for this.
os.environ["NETHACKOPTIONS"] = ",".join(options)
os.environ["HACKDIR"] = hackdir
os.environ["TERM"] = os.environ.get("TERM", "screen")
if wizkit is not None:
os.environ["WIZKIT"] = wizkit


# TODO: Not thread-safe for many reasons.
Expand Down Expand Up @@ -109,6 +112,7 @@ def __init__(
self._options = list(options) + ["name:" + playername]
if wizard:
self._options.append("playmode:debug")
self._wizard = wizard

_set_env_vars(self._options, self._vardir)
self._ttyrec = ttyrec
Expand All @@ -134,8 +138,20 @@ def step(self, action):
self._pynethack.step(action)
return self._step_return(), self._pynethack.done()

def reset(self, new_ttyrec=None):
_set_env_vars(self._options, self._vardir)
def _write_wizkit_file(self, wizkit_items):
# TODO ideally we need to check the validity of the requested items
with open(os.path.join(self._vardir, WIZKIT_FNAME), "a") as f:
for item in wizkit_items:
f.write(f"{item}\n")

def reset(self, new_ttyrec=None, wizkit_items=None):
if wizkit_items is not None:
if not self._wizard:
raise ValueError("Set wizard=True to use the wizkit option.")
self._write_wizkit_file(wizkit_items)
_set_env_vars(self._options, self._vardir, wizkit=WIZKIT_FNAME)
else:
_set_env_vars(self._options, self._vardir)
if new_ttyrec is None:
self._pynethack.reset()
else:
Expand Down
44 changes: 44 additions & 0 deletions nle/tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,50 @@ def test_default_wizard_mode(self, env_name, wizard):
assert "playmode:debug" not in env.env._options


class TestWizkit:
@pytest.yield_fixture(autouse=True) # will be applied to all tests in class
def make_cwd_tmp(self, tmpdir):
"""Makes cwd point to the test's tmpdir."""
with tmpdir.as_cwd():
yield

def test_meatball_exists(self):
"""Test loading stuff via wizkit"""
env = gym.make("NetHack-v0", wizard=True)
found = dict(meatball=0)
obs = env.reset(wizkit_items=list(found.keys()))
for line in obs["inv_strs"]:
if np.all(line == 0):
break
for key in found:
if key in line.tobytes().decode("utf-8"):
found[key] += 1
for key, count in found.items():
assert key == key and count > 0
del env

def test_wizkit_no_wizard_mode(self):
env = gym.make("NetHack-v0", wizard=False)
with pytest.raises(ValueError) as e_info:
env.reset(wizkit_items=["meatball"])
assert e_info.value.args[0] == "Set wizard=True to use the wizkit option."

def test_wizkit_file(self):
env = gym.make("NetHack-v0", wizard=True)
req_items = ["meatball", "apple"]
env.reset(wizkit_items=req_items)
path_to_wizkit = os.path.join(env.env._vardir, nethack.nethack.WIZKIT_FNAME)

# test file exists
os.path.exists(path_to_wizkit)

# test that file content corresponds to what you requested
with open(path_to_wizkit, "r") as f:
for item, line in zip(req_items, f):
assert item == line.strip()
del env


@pytest.mark.parametrize("env_name", [e for e in get_nethack_env_ids() if "Score" in e])
class TestBasicGymEnv:
def test_inventory(self, env_name):
Expand Down