[Feature Request] Reward-to-go #16
Comments
Yes, this is closed now!
Nice! @vmoens Do you know if this is meant to handle truncation correctly (e.g., as introduced in Gymnasium)? I.e., does it allow giving a value for the truncation state?
Similarly, for episodes that have not finished yet (typically the last episode in each concurrent environment), is there a way to find those and mask them out in the loss?
I think that, provided you pass the correct mask to the function, truncation should be handled properly. This is the transform. It looks at done or truncated here. Let us know if something is not clear!
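In other words (just a sketch of the idea, not the library's actual code): a step counts as an episode boundary if it is either done or truncated.

import torch

done = torch.tensor([[False], [False], [True], [False]])
truncated = torch.tensor([[False], [True], [False], [False]])

# Sketch of the idea only: the episode boundaries the transform works from
# correspond to steps where either flag is set, i.e. the elementwise OR.
end_of_episode = done | truncated
print(end_of_episode)  # True at indices 1 and 2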
Yes, when an episode has ended without done=True, truncated is set to True on that last state, and the transform handles it as if it were the last state of the episode.
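Here is a minimal sketch of that behavior, reusing the Reward2GoTransform pattern from the example further down in this thread (the rewards are made up, and gamma defaults to 1, so the reward-to-go is a plain sum):

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import Reward2GoTransform

r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])

# Three steps of one episode that was cut short: done stays False everywhere,
# but truncated is True on the last collected step.
td = TensorDict(
    {
        "reward": torch.ones(3, 1),
        "next": {
            "done": torch.zeros(3, 1, dtype=torch.bool),
            "truncated": torch.tensor([[False], [False], [True]]),
        },
    },
    batch_size=[],
)

print(r2g._inv_call(td)["reward_to_go"])
# expected: [3., 2., 1.] -- the same values you would get with done=True on the last step.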
If you want to mask these episodes out completely, you might have to set the states, rewards, actions, etc. to zero. Simply setting truncated to True for all those steps would not work: the reward-to-go transform would then return only the current reward per step, because it treats each step as a single episode of length 1.
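For concreteness, a sketch of that pitfall with the same Reward2GoTransform usage as in the example further down (the rewards 1 to 4 are made up for illustration):

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import Reward2GoTransform

r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])

# Marking *every* step as truncated turns each step into its own length-1
# episode, so the "reward-to-go" collapses to the raw per-step reward.
td = TensorDict(
    {
        "reward": torch.arange(1.0, 5.0).reshape(4, 1),  # rewards 1, 2, 3, 4
        "next": {
            "done": torch.zeros(4, 1, dtype=torch.bool),
            "truncated": torch.ones(4, 1, dtype=torch.bool),
        },
    },
    batch_size=[],
)

print(r2g._inv_call(td)["reward_to_go"])
# expected: [1., 2., 3., 4.] -- no accumulation across steps.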
Let me know if this helped to clarify!
Thanks both for the clarifications. So I guess the answer to my question is that the transform is aware of truncation and handles it as termination. Also, it does not expect the last state to have ["next"]["truncated"] = True or ["next"]["done"] = True; it only complains if there is no done or truncated at all in the batch.

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][1]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],  # Belongs to truncated episode.
        [1.],  # Belongs to truncated episode. Next is truncated.
        [2.],  # Belongs to unfinished episode.
        [1.]]) # Belongs to unfinished episode. Next is not truncated, nor done.

Otherwise, regarding the last-step truncation:
Where does this happen? It doesn't seem to be set by a collector at the last frame of a batch. (I'm new to TorchRL and in the process of deciding whether I should adopt it!)
Implement reward-to-go (as in here)
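For reference, a minimal framework-free sketch of the requested quantity (the standard definition, not TorchRL's eventual implementation): the reward-to-go at step t is the discounted sum of rewards from t to the end of its episode, which can be computed with a reversed running sum.

import torch

def reward_to_go(rewards: torch.Tensor, done: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """rewards and done have shape [T, 1]; done marks the last step of each episode."""
    out = torch.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(rewards.shape[0])):
        if done[t]:
            running = 0.0  # episode boundary: restart the accumulation
        running = rewards[t].item() + gamma * running
        out[t] = running
    return out

rewards = torch.ones(4, 1)
done = torch.tensor([[False], [True], [False], [False]])
print(reward_to_go(rewards, done))  # reward-to-go: [2., 1., 2., 1.]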