[Algorithm] Added TQC #1631
base: main
Conversation
Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for this!
I made a quick initial review, but generally it looks like it's on the right track!
Would it make sense to host this loss in objectives? That way we could write tests for it and adopt it as a "feature" of the lib itself.
I can help write the tests.
examples/tqc/utils.py
sorted_z_part = sorted_z[:, :-self.top_quantiles_to_drop]

# compute target
target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi)
Can we use our value estimators for this?
There is a slight issue / inconvenience with this: the shape of reward is [batch size, 1], whereas the shape of the bracketed term on the right-hand side is [batch size, n_nets * n_quantiles - top_quantiles_to_drop] (so with the standard hyperparameters we would have [batch size, 115]).
In the computation of target there is therefore some broadcasting going on, and target ends up having shape [batch size, n_nets * n_quantiles - top_quantiles_to_drop]. The critic outputs a tensor of shape [batch size, n_nets, n_quantiles] (although this could of course be reshaped to [batch size, n_nets * n_quantiles], i.e. [batch size, 125] with the standard hyperparameters).
However, the TD0Estimator (or presumably any other estimator) expects "reward" and "state_action_value" to have the same shape, i.e. it does not support broadcasting in the computation of target. A small sketch of the shapes involved follows below.
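For concreteness, here is a self-contained illustration of these shapes (the batch size, reward values and alpha are made up; only n_nets, n_quantiles and top_quantiles_to_drop match the defaults discussed above):

import torch

batch_size, n_nets, n_quantiles, top_quantiles_to_drop = 256, 5, 25, 10
gamma, alpha = 0.99, 0.2

reward = torch.zeros(batch_size, 1)                     # [batch size, 1]
not_done = torch.ones(batch_size, 1)
next_log_pi = torch.zeros(batch_size, 1)
next_z = torch.randn(batch_size, n_nets, n_quantiles)   # critic output: [batch size, 5, 25]

sorted_z = torch.sort(next_z.reshape(batch_size, -1)).values
sorted_z_part = sorted_z[:, :-top_quantiles_to_drop]    # [batch size, 115]

# Broadcasting [batch size, 1] against [batch size, 115] yields a target of
# shape [batch size, 115], which the TD0Estimator does not accept.
target = reward + not_done * gamma * (sorted_z_part - alpha * next_log_pi)
print(reward.shape, sorted_z_part.shape, target.shape)
# torch.Size([256, 1]) torch.Size([256, 115]) torch.Size([256, 115])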
Not sure what a reasonable solution here would be. The core issue seems to be that TD0Estimator does not allow for broadcasting in the computation of target; however we want to broadcast here since the critic doesn't output one value estimate but a collection of quantile estimates for the q value function.
The error arising from this issue (critic output of different shape than reward) is:
RuntimeError: All input tensors (value, reward and done states) must share a unique shape.
I have now added some code to demonstrate this problem in the latest commit e208692 (it's a bit too cumbersome to post all the code in a comment). Running it should reproduce the issue.
Sorry for leaving this dormant for so long.
It's ok like this if the value estimator can't be adapted. I would just add a dummy make_value_estimator that does nothing other than check the value estimator type and tell the user if it doesn't fit what is required (no need to set it, since the code takes care of it).
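A minimal sketch of what such a dummy make_value_estimator could look like (the method body and error message are assumptions, not part of this PR; only the import and the LossModule base class come from the existing code):

from torchrl.objectives import LossModule
from torchrl.objectives.utils import ValueEstimators


class TQCLoss(LossModule):
    ...

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        # TQC builds its quantile targets inside the loss, so there is nothing
        # to construct here; we only validate the requested estimator type.
        if value_type not in (None, ValueEstimators.TD0):
            raise NotImplementedError(
                f"Value estimator {value_type} is not supported by TQCLoss; "
                "the loss computes its truncated-quantile targets internally."
            )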
No worries! Ok - I'll see to doing that on Monday 😊
Do you think it would make sense to allow the value estimators to accept critic outputs of a different shape, as is needed for TQC? What's the reason for imposing the restriction on the critic shape?
Done. Would you like me to move the TQCLoss implementation into the objectives directory?
I will draft the move for you and maybe leave it to you to complete the missing bits?
I'll have a look at how the critic works in this case with respect to the shape restriction.
The reason we have that restriction is quite simply that not enforcing it can lead to issues :)
The most common is broadcasting: say you pass a reward of shape [T, 1] and a done state of shape [T]; value estimation without a check will broadcast the shapes and return a value of shape [T, T].
Because you can always expand a tensor with zero memory cost (as we do in MARL), we prefer to enforce that all shapes match ahead of time.
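As a quick illustration of that pitfall (the shapes are made up):

import torch

T = 5
reward = torch.zeros(T, 1)   # [T, 1]
done = torch.zeros(T)        # [T]

# Without a shape check, elementwise ops silently broadcast [T, 1] with [T]
# (treated as [1, T]) and produce a [T, T] tensor instead of [T, 1]:
value = reward * (1 - done)
print(value.shape)  # torch.Size([5, 5])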
Sounds good!
That makes sense, thanks for explaining! Perhaps one could add an option to explicitly specify the shapes of reward, value and done states when constructing the estimator, and then do some consistency checks on those shapes (a rough sketch follows below)?
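Purely as a hypothetical sketch of that option (none of these names exist in torchrl; the shapes mirror the TQC case discussed above):

import torch


def check_td_shapes(reward, value, done, *, reward_shape, value_shape, done_shape):
    # Hypothetical consistency check: instead of requiring identical shapes,
    # validate each tensor against an explicitly declared shape.
    for name, tensor, expected in (
        ("reward", reward, reward_shape),
        ("value", value, value_shape),
        ("done", done, done_shape),
    ):
        if tuple(tensor.shape) != tuple(expected):
            raise ValueError(
                f"{name} has shape {tuple(tensor.shape)}, expected {tuple(expected)}"
            )


# Example: TQC-style shapes, where value has more trailing elements than reward.
check_td_shapes(
    torch.zeros(256, 1),
    torch.zeros(256, 115),
    torch.zeros(256, 1),
    reward_shape=(256, 1),
    value_shape=(256, 115),
    done_shape=(256, 1),
)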
Thank you Vincent! Will get to work on those changes now. I think adding the TQCLoss to the objectives would be a great idea. If you could assist in writing tests, that would be great.
Just to keep you in the loop:
Amazing! Definitely do let me know if I can help. Would love to hear your opinion on the issue with the value estimators (see the comment and the latest commit).
@maxweissenbacher I made a PR against your main branch with proposed changes
Refactor TQC
Hi @maxweissenbacher
Hi Vincent, my apologies for leaving this dormant for so long, I have been busy with moving house. I will get to work on it this week!
from torchrl.objectives.utils import ValueEstimators


class TQCLoss(LossModule):
The docstring should be added (a possible skeleton is sketched after this list). It should include:
- a link to the arxiv of the paper or similar
- a short description of what this algo is about
- an example of how to create and run the module
- some notes regarding value estimation etc.
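A possible skeleton for that docstring, to be placed right under the class definition (the wording and the usage example are only indicative; the link points to the original TQC paper):

"""TQC loss module.

Truncated Quantile Critics (TQC), as introduced in "Controlling Overestimation
Bias with Truncated Mixture of Continuous Distributional Quantile Critics"
(Kuznetsov et al., 2020), https://arxiv.org/abs/2005.04269.

The loss trains an ensemble of distributional (quantile) critics and drops the
top quantiles of the pooled distribution when building the target, which
controls overestimation bias.

Example:
    >>> loss = TQCLoss(actor_network, qvalue_network, top_quantiles_to_drop=10)
    >>> loss_vals = loss(tensordict)

.. note::
    The quantile targets are computed inside the loss, so no separate value
    estimator needs to be configured.
"""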
# no need to pass device, should be handled by actor/qvalue nets
# device: torch.device,
# gamma should be passed to the value estimator construction
# for consistency with other losses
# gamma: float=None,
these notes can be removed
self,
actor_network: TensorDictModule,
qvalue_network: TensorDictModule,
top_quantiles_to_drop: float = 10,
Suggested change:
- top_quantiles_to_drop: float = 10,
+ *,
+ top_quantiles_to_drop: float = 10,
create_target_params=True,  # Create a target critic network
)

# self.device = device
Suggested change:
- # self.device = device
self.actor(td_next, params=self.actor_params)
self.critic(td_next, params=self.target_critic_params)
Suggested change:
- self.actor(td_next, params=self.actor_params)
- self.critic(td_next, params=self.target_critic_params)
+ with self.actor_params.to_module(self.actor):
+     self.actor(td_next)
+ with self.target_critic_params.to_module(self.critic):
+     self.critic(td_next)
# compute and cut quantiles at the next state
next_z = td_next.get(self.tensor_keys.state_action_value)
sorted_z, _ = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1))
Suggested change:
- sorted_z, _ = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1))
+ sorted_z = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1)).values
sorted_z_part - alpha * next_log_pi
)

self.critic(tensordict_copy, params=self.critic_params)
Suggested change:
- self.critic(tensordict_copy, params=self.critic_params)
+ with self.critic_params.to_module(self.critic):
+     self.critic(tensordict_copy)
self.actor(tensordict_copy, params=self.actor_params)
self.critic(tensordict_copy, params=self.critic_params)
Suggested change:
- self.actor(tensordict_copy, params=self.actor_params)
- self.critic(tensordict_copy, params=self.critic_params)
+ with self.actor_params.to_module(self.actor):
+     self.actor(tensordict_copy)
+ with self.critic_params.to_module(self.critic):
+     self.critic(tensordict_copy)
with set_exploration_type(ExplorationType.RANDOM):
    dist = self.actor.get_dist(
        tensordict,
        params=self.actor_params,
    )
    a_reparm = dist.rsample()
Suggested change:
- with set_exploration_type(ExplorationType.RANDOM):
-     dist = self.actor.get_dist(
-         tensordict,
-         params=self.actor_params,
-     )
-     a_reparm = dist.rsample()
+ with set_exploration_type(ExplorationType.RANDOM), self.actor_params.to_module(self.actor):
+     dist = self.actor.get_dist(
+         tensordict
+     )
+     a_reparm = dist.rsample()
No worries at all! I was just checking in to see if there was anything blocking you really :)
Description
Added an implementation of the TQC algorithm. I've tried to stay as close to the original PyTorch implementation as possible. The main pieces are:
- a utils.py file, where the critic and actor architectures and the TQC loss are implemented;
- a TQCLoss class, which is a subclass of LossModule. Currently this class is implemented in the utils.py file, but it could probably be enhanced a little and then added to the objectives directory of torchrl, where for instance the SACLoss classes etc. are implemented. Please feel free to comment on whether you think this should be done and what changes to my current TQCLoss implementation you would recommend.
Motivation and Context
This closes issue #1623.
Types of changes
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!