Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tae898 committed Apr 12, 2024
1 parent 75750de commit 82284f8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 5 additions & 4 deletions room_env/envs/room2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def __init__(

# place an object in one of the rooms when it is created.
if self.deterministic:
self.location = sample_max_value_key(
self.init_probs, keys_to_exclude=["stay"]
)
self.location = sample_max_value_key(self.init_probs)

else:
self.location = random.choices(
Expand Down Expand Up @@ -150,6 +148,7 @@ def __init__(
init_probs: dict,
transition_probs: dict,
question_prob: float,
deterministic: bool,
) -> None:
"""Static object does not move. Once they are initialized, they stay forever.
Expand All @@ -160,6 +159,7 @@ def __init__(
transition_probs: just a place holder. It's not gonna be used anyway.
question_prob: the probability of a question being asked at every
observation
deterministic: whether the object is deterministic.
"""
super().__init__(
Expand All @@ -168,7 +168,7 @@ def __init__(
init_probs,
transition_probs,
question_prob,
deterministic=True,
deterministic=deterministic,
)
assert self.transition_probs is None, "Static objects do not move."

Expand Down Expand Up @@ -573,6 +573,7 @@ def _create_objects(self) -> None:
init_probs,
self.object_transition_config["static"][name],
self.object_question_probs["static"][name],
self.deterministic_objects,
)
)

Expand Down
4 changes: 4 additions & 0 deletions test/room_env2/test_room_env_v2_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_all(self) -> None:
"room1": {"north": 0, "east": 0, "south": 0, "west": 0, "stay": 0},
},
question_prob=0.5,
deterministic=True,
)
with self.assertRaises(AssertionError):
foo = StaticObject(
Expand All @@ -188,18 +189,21 @@ def test_all(self) -> None:
"room1": {"north": 0, "east": 0, "south": 0, "west": 0, "stay": 0},
},
question_prob=0.5,
deterministic=True,
)
foo = StaticObject(
name="foo",
init_probs={"room0": 1.0, "room1": 0},
transition_probs=None,
question_prob=0.5,
deterministic=True,
)
bar = StaticObject(
name="foo",
init_probs={"room0": 1.0, "room1": 0},
transition_probs=None,
question_prob=0.5,
deterministic=True,
)
self.assertEqual(foo, bar)

Expand Down

0 comments on commit 82284f8

Please sign in to comment.