A Bug in HERReplayBuffer. #811
Hello, I have some questions about the code of `rewrite_transitions` in `HERReplayBuffer`: as `ReplayBuffer` is implemented as a circular queue, the indices in `terminal` may be less than the corresponding indices in `current`, so some elements in `future_offset` may be negative, which makes the states in `future_t` not the desired future states. @Juno-T
Thanks for the question. I think this case was covered by the first line. However, if you have some code illustrating the bug, please feel free to share it so that we can investigate further.
For reference, there is a test case for cycled indices of `HERReplayBuffer` in test_buffer.py.
@Juno-T Thanks for your reply. I found that the variable `terminal` may be less than `current` when an episode wraps around the buffer, as `ReplayBuffer` is a circular queue.

So according to the paper, they found that instead of rewriting the goal with the terminal state's achieved goal, it is equally good or better to use a random future state's achieved goal (see Fig. 6). If you look at the calculation of `future_t` in `rewrite_transitions`, the future timestep is meant to be drawn from the remainder of the same episode.
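For illustration, here is a minimal sketch of that "future" relabeling strategy from the paper (the helper name and signature are hypothetical, not Tianshou's implementation; `future_p` is taken directly as a parameter):

```python
import numpy as np


def her_future_relabel(achieved_goals, desired_goals, compute_reward_fn,
                       future_p=0.8, rng=None):
    """Relabel one episode with the HER 'future' strategy.

    For each timestep t, with probability future_p, replace the desired
    goal with the achieved goal of a random future timestep t' in [t, T),
    then recompute the reward against the (possibly relabeled) goal.
    """
    rng = rng if rng is not None else np.random.default_rng()
    T = len(achieved_goals)
    new_goals = list(desired_goals)
    rewards = []
    for t in range(T):
        if rng.uniform() < future_p:
            future_t = rng.integers(t, T)  # random future step, same episode
            new_goals[t] = achieved_goals[future_t]
        rewards.append(compute_reward_fn(achieved_goals[t], new_goals[t]))
    return new_goals, rewards
```

The key invariant is that `future_t` is drawn from the same episode, which is exactly what the circular buffer breaks in the reproduction below.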
@Juno-T I think I understand what you mean, but the calculation of `future_offset` can still produce negative offsets. The following code reproduces the problem:

```python
import numpy as np

from tianshou.data import Batch, HERReplayBuffer
from env import MyGoalEnv  # from test.base.env import MyGoalEnv


def test_herreplaybuffer():
    env_size = 99
    bufsize = 15
    env = MyGoalEnv(env_size, array_state=False)

    def compute_reward_fn(ag, g):
        return env.compute_reward_fn(ag, g, {})

    buf = HERReplayBuffer(
        bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
    )
    buf.future_p = 1
    for x, ep_len in enumerate([10, 20]):
        obs, _ = env.reset()
        for i in range(ep_len):
            act = 1
            obs_next, rew, terminated, truncated, info = env.step(act)
            batch = Batch(
                obs=obs,
                act=[act],
                rew=rew,
                terminated=(i == ep_len - 1),
                truncated=(i == ep_len - 1),
                obs_next=obs_next,
                info=info,
            )
            # Skip the first 10 steps of the second episode so that its
            # stored states are numbered 10-19.
            if x == 1 and obs["observation"] < 10:
                obs = obs_next
                continue
            buf.add(batch)
            obs = obs_next

    # The loop above generates two episodes: states numbered 0 to 9 for the
    # first episode, and states numbered 10 to 19 for the second episode.
    buf._restore_cache()
    sample_indices = np.array([10])  # suppose the sampled indices are [10]
    buf.rewrite_transitions(sample_indices)
    print(buf.obs)
    # The printed desired_goal for index 10 may be 6, 7, 8, 9, 10, or 11,
    # but 6-9 are states in episode 1.


if __name__ == '__main__':
    test_herreplaybuffer()
```

In the above code, the `desired_goal` of the transition at index 10 can be rewritten with achieved goals taken from episode 1, which are not future states of the same episode.
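As a worked sketch of the index arithmetic behind this reproduction (values taken from the script above: bufsize 15, episode 2 occupies indices 10-14 and wraps to 0-4, so `current = 10` and `terminal = 4`):

```python
import numpy as np

maxsize = 15
current = np.array([10])  # sampled index: first transition of episode 2
terminal = np.array([4])  # terminal of episode 2 after wrapping around

# Current behaviour: the distance goes negative once the episode wraps.
buggy_len = terminal - current                                   # array([-6])
offset = (np.random.uniform(size=1) * buggy_len).astype(int)     # in {-5, ..., 0}
future_t_buggy = (current + offset) % maxsize                    # in {5, ..., 10}
# Indices 5-9 still hold episode 1, so the "future" goal leaks across episodes.

# Distance measured modulo the buffer size stays non-negative.
episodes_len = (terminal - current + maxsize) % maxsize          # array([9])
offset = (np.random.uniform(size=1) * episodes_len).astype(int)  # in {0, ..., 8}
future_t = (current + offset) % maxsize                          # in {10, ..., 14} or {0, ..., 3}
```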
@Juno-T If we change the code to … Here is an explanation for my code:

(figure not preserved in the text)
Thank you for the clarification. I have confirmed the bug locally, and it is indeed a bug in `rewrite_transitions`. I think this should fix it:

```python
def rewrite_transitions(self, indices: np.ndarray) -> None:
    ...
    # Calculate future timestep to use
    current = indices[0]
    terminal = indices[-1]
    # Episode length measured modulo the buffer size, so it stays
    # non-negative even when an episode wraps around the circular buffer.
    episodes_len = (terminal - current + self.maxsize) % self.maxsize
    future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
    future_offset = future_offset.astype(int)
    future_t = (current + future_offset) % self.maxsize
    ...
```

Would you like to create a pull request to fix this bug? You can add your test case to test_buffer.py and also verify some edge cases.
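For the edge-case checks, something along these lines could sit next to the reproduction above (a sketch only: the helper name is made up, and it assumes MyGoalEnv's scalar goals and the `_restore_cache` pattern used earlier):

```python
import numpy as np


def check_no_cross_episode_goals(buf, index, ep_goals, trials=100):
    """Rewrite the transition at `index` repeatedly and assert that every
    rewritten desired_goal is an achieved goal of the same episode."""
    for _ in range(trials):  # repeat to cover the random future_offset draws
        buf._restore_cache()
        buf.rewrite_transitions(np.array([index]))
        goal = buf[index].obs.desired_goal
        assert goal in ep_goals, f"goal {goal} leaked from another episode"


# e.g. check_no_cross_episode_goals(buf, 10, set(range(10, 21))),
# where 10-20 is the assumed range of episode 2's achieved goals.
```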
@Juno-T Thanks for your reply. I will make a pull request when I'm free.