
[Feature Request] Reward-to-go #16

Closed
vmoens opened this issue Feb 11, 2022 · 7 comments
Labels
Good first issue A good way to start hacking torchrl! new algo New algorithm request or PR

Comments

@vmoens
Contributor

vmoens commented Feb 11, 2022

Implement reward-to-go (as in here)
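For reference, the reward-to-go at step $t$ is the (optionally discounted) sum of rewards from $t$ to the end of its episode, $G_t = \sum_{k=t}^{T} \gamma^{k-t} r_k$. It can be computed backwards in time with $G_t = r_t + \gamma G_{t+1}$. The following is a minimal pure-Python sketch of that definition. It assumes `dones[t]` is True on the last step of each episode, and the function name `reward_to_go` is hypothetical, not a TorchRL API:

```python
from typing import List


def reward_to_go(rewards: List[float], dones: List[bool], gamma: float = 1.0) -> List[float]:
    """Discounted reward-to-go G_t = r_t + gamma * G_{t+1}, reset at episode ends.

    Hypothetical helper: dones[t] is True when step t is the last step of its episode.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards in time so each G_t can reuse G_{t+1}.
    for t in reversed(range(len(rewards))):
        if dones[t]:
            # Nothing after the end of an episode contributes to this step.
            running = 0.0
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```

For example, `reward_to_go([1.0] * 4, [False, False, False, True])` gives `[4.0, 3.0, 2.0, 1.0]`.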

@vmoens vmoens added the new algo New algorithm request or PR label Feb 11, 2022
@vmoens vmoens self-assigned this Feb 11, 2022
@Benjamin-eecs Benjamin-eecs changed the title Reward-to-go [Feature Request] Reward-to-go Jul 21, 2022
@vmoens vmoens added the Good first issue A good way to start hacking torchrl! label Jan 5, 2023
@skandermoalla
Contributor

@vmoens Is this issue still open? @BY571 Does #1038 close it?

@vmoens
Contributor Author

vmoens commented May 17, 2023

Yes, this is closed now!

@vmoens vmoens closed this as completed May 17, 2023
@skandermoalla
Contributor

skandermoalla commented May 17, 2023

Nice! @vmoens Do you know whether this is meant to handle truncation correctly (e.g., as introduced in Gymnasium)? That is, does it let you provide a value for the truncation state, $V(s_{truncation})$, from which to bootstrap all the reward-to-go values $G_t$ in that episode?
Or, if this is deferred to the learning logic, is there a flag marking states that belong to truncated episodes, so that one could bootstrap their values (e.g., $\nabla \log \pi(a_t \mid s_t)\left[G_t + \mathbb{1}_{\{s_t \in \text{truncatedEpisode}\}} \, V(s_{truncation}) \dots\right]$)?
I haven't kept up with the changes since #403.
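A minimal sketch of the bootstrapping described above, written in plain PyTorch rather than against any TorchRL API. The function `reward_to_go_bootstrapped` and its argument names are hypothetical. It assumes 1-D tensors over time for a single environment, where `next_values[t]` is an estimate of the value of the state reached after step t. At a terminated step the future is worth zero. At a truncated step the recursion is seeded with the value estimate instead:

```python
import torch


def reward_to_go_bootstrapped(rewards, terminated, truncated, next_values, gamma=0.99):
    """Reward-to-go that bootstraps from V(s_{t+1}) at truncated steps.

    Hypothetical helper. All arguments are 1-D tensors over time; next_values[t]
    is only read where truncated[t] is True.
    """
    returns = torch.zeros_like(rewards)
    running = torch.zeros((), dtype=rewards.dtype)
    for t in reversed(range(rewards.shape[0])):
        if terminated[t]:
            # True terminal state: the future is worth zero.
            running = torch.zeros((), dtype=rewards.dtype)
        elif truncated[t]:
            # Time-limit cut: stand in for the missing future with V(s_{t+1}).
            running = next_values[t]
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```

With `gamma=1.0`, every $G_t$ in the truncated episode is shifted by exactly the bootstrap value, which matches the indicator term in the formula above. With `gamma < 1`, the bootstrap term is discounted by the distance to the truncation point.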

@skandermoalla
Contributor

Similarly, for episodes that have not finished yet (typically the last episode in each concurrent environment), is there a way to find those and mask them out in the loss?
Thanks!
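One possible way to find such steps is sketched below. It assumes a single environment's trajectory stored as 1-D boolean `done` and `truncated` tensors: every step after the last episode end belongs to an unfinished episode. The helper `unfinished_mask` is hypothetical, not part of TorchRL:

```python
import torch


def unfinished_mask(done: torch.Tensor, truncated: torch.Tensor) -> torch.Tensor:
    """True for steps that come after the last episode end in a 1-D time series.

    Hypothetical helper, not a TorchRL API.
    """
    ended = done | truncated
    # Count the episode ends at or after each step (reverse cumulative sum).
    ends_from_here = ended.flip(0).long().cumsum(0).flip(0)
    # A step belongs to an unfinished episode if no episode end occurs at or after it.
    return ends_from_here == 0
```

The mask could then be used to drop those steps from a per-step loss, e.g. `per_step_loss[~mask].mean()`. For several concurrent environments, the same computation would be applied along each environment's time dimension.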

@vmoens
Contributor Author

vmoens commented May 18, 2023

I think that, provided you pass the correct mask to the function, truncation should be handled properly.
@BY571 can you confirm?

This is the transform. It looks at done or truncated here.
The functional treats both as "done": as you can see, upstream the transform computes done = done | truncated.
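To make the effect of `done = done | truncated` concrete, here is a plain-PyTorch sketch. It is not TorchRL's actual functional, and `reward_to_go_merged` is a hypothetical name. Once the two flags are merged, the backward recursion stops accumulating at any step flagged as either, so a truncated episode is summed exactly like a terminated one:

```python
import torch


def reward_to_go_merged(reward: torch.Tensor, done: torch.Tensor,
                        truncated: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """Reward-to-go along dim 0 (time), treating truncation as termination.

    Hypothetical helper, not TorchRL's implementation.
    """
    # Mirror the transform's preprocessing: any truncated step ends the episode.
    end = done | truncated
    out = torch.empty_like(reward)
    running = torch.zeros_like(reward[0])
    for t in reversed(range(reward.shape[0])):
        # Keep the accumulated future only if step t did not end an episode.
        keep = (~end[t]).to(reward.dtype)
        running = reward[t] + gamma * running * keep
        out[t] = running
    return out
```

Use rewards of ones with shape [4, 1] and set truncated only at index 1. This returns [[2.], [1.], [2.], [1.]], the same values the transform produces in the example further down this thread. The last two steps are summed as if the end of the batch were an episode end.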

Let us know if something is not clear!

@BY571
Contributor

BY571 commented May 18, 2023

Yes. When an episode ends without done=True but truncated is set to True on its last state, the transform handles that state as the last state of the episode:

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-1]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[4.],
        [3.],
        [2.],
        [1.]])

If you want to mask these episodes out completely, you may need to set their states, rewards, actions, etc. to zero. Simply setting truncated to True for all of those steps does not work: the reward-to-go transform then returns only the current reward at each step, because it treats each step as a single episode of length 1:

>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-3:]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],
        [1.],
        [1.],
        [1.]])
      
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-3:]=True
>>> td["reward"][-3:]=0
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[1.],
        [0.],
        [0.],
        [0.]])

Let me know if this helped to clarify

@skandermoalla
Contributor

skandermoalla commented May 22, 2023

Thanks both for the clarifications.

So I guess the answer to my question is that the transform is aware of truncation and treats it as termination.
Consequently, it will neither bootstrap truncated episodes nor mask unfinished ones; this is left to the user.

Also, it does not require the last state to have ["next"]["truncated"] = True or ["next"]["done"] = True; it only complains if there is no done or truncated flag anywhere in the batch.

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][1]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],	# Belongs to truncated episode.
        [1.],	# Belongs to truncated episode. Next is truncated.
        [2.],	# Belongs to unfinished episode.
        [1.]])	# Belongs to unfinished episode. Next is not truncated, nor done.

Otherwise, regarding the last step truncation:

Yes, when an episode was ended (without done=True) truncated is set true

Where does this happen? It doesn't seem to be done by a collector at the last frame of a batch. (I’m new to TorchRL and in the process of deciding whether I should adopt it!)


3 participants