
[Algorithm] Added TQC #1631

Open · wants to merge 16 commits into base: main

Conversation

@maxweissenbacher (Contributor) commented Oct 17, 2023

Description

Added implementation of TQC algorithm. I've tried to stay as close to the original PyTorch implementation as possible.

  • The implementation is based on the SAC example in the examples folder (here).
  • The main differences from SAC are in the utils.py file, where the critic and actor architectures and the TQC loss are implemented.
  • Added a TQCLoss class, a subclass of LossModule. This class currently lives in utils.py, but it could probably be polished a bit and moved into the objectives directory of torchrl, where for instance the SACLoss classes are implemented. Please feel free to comment on whether you think this should be done and what changes to my current TQCLoss implementation you would recommend.
  • I have tested the code on the HalfCheetah-v4 environment and it seems to perform well (see screenshots below). I am happy to conduct more extensive tests (such as several runs with different seeds to assess robustness, or training on different environments) if you think that would be a helpful check.

[Screenshot: eval reward]

[Screenshot: training losses]

Motivation and Context

This PR closes #1623

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.*
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • Might be a good idea to add a few test runs of TQC to the EXAMPLES.md file.

@facebook-github-bot

Hi @maxweissenbacher!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@maxweissenbacher changed the title from "Added TQC implementation (initial commit)" to "Added TQC" on Oct 17, 2023
@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the "CLA Signed" label on Oct 17, 2023. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@vmoens (Contributor) left a comment:

Thanks for this!

I made a quick initial review, but generally it looks like it's on the right track!

Would it make sense to host this loss in objectives? That way we could write tests for it and adopt it as a "feature" of the library itself.
I can help write the tests.

examples/tqc/utils.py (outdated review thread, resolved)
sorted_z_part = sorted_z[:, :-self.top_quantiles_to_drop]

# compute target
target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi)
Contributor:

Can we use our value estimators for this?

@maxweissenbacher (Contributor Author) commented Oct 20, 2023:

There is a slight issue/inconvenience with this: the shape of reward is [batch size, 1], whereas the shape of the bracketed term on the right-hand side is [batch size, n_nets * n_quantiles - top_quantiles_to_drop] (so with the standard hyperparameters we would have [batch size, 115]).

In the computation of target there is therefore some broadcasting going on, and target ends up with shape [batch size, n_nets * n_quantiles - top_quantiles_to_drop]. The critic outputs a tensor of shape [batch size, n_nets, n_quantiles] (although this could of course be reshaped to [batch size, n_nets * n_quantiles], i.e. [batch size, 125] with the standard hyperparameters).

However, the TD0Estimator (or presumably any other estimator) expects "reward" and "state_action_value" to have the same shape, i.e. it does not support broadcasting in the computation of target.
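To make the shape mismatch concrete, here is a small standalone sketch with the standard hyperparameters mentioned above (an illustration only, not the actual loss code):

```python
import torch

batch_size, n_nets, n_quantiles, top_quantiles_to_drop = 256, 5, 25, 10

reward = torch.randn(batch_size, 1)            # shape [B, 1]
next_log_pi = torch.randn(batch_size, 1)
# pooled, sorted quantiles after dropping the top ones -> shape [B, 115]
sorted_z_part = torch.randn(batch_size, n_nets * n_quantiles - top_quantiles_to_drop)
not_done, gamma, alpha = 1.0, 0.99, 0.2

# Broadcasting [B, 1] against [B, 115] gives a target of shape [B, 115],
# while TD0Estimator expects reward and state_action_value to share one shape.
target = reward + not_done * gamma * (sorted_z_part - alpha * next_log_pi)
print(target.shape)  # torch.Size([256, 115])
```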

@maxweissenbacher (Contributor Author):

Not sure what a reasonable solution would be here. The core issue seems to be that TD0Estimator does not allow broadcasting in the computation of target; however, we want to broadcast here, since the critic doesn't output a single value estimate but a collection of quantile estimates of the Q-value function.

@maxweissenbacher (Contributor Author):

The error arising from this issue (critic output of a different shape than reward) is:

RuntimeError: All input tensors (value, reward and done states) must share a unique shape.

@maxweissenbacher (Contributor Author) commented Oct 24, 2023:

I have now added some code to demonstrate this problem in the latest commit e208692 (it's a bit too cumbersome to post all the code in a comment). Running it should reproduce the issue.

Contributor:

Sorry for leaving this dormant for so long.
It's OK like this if the value estimator can't be adapted. I would just add a dummy make_value_estimator that does nothing other than check the value estimator type and tell the user if it doesn't fit what is required (no need to set it, since the code takes care of it).
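For concreteness, a rough sketch of what such a dummy make_value_estimator might look like (the accepted estimator types and the error message are assumptions, not the final implementation):

```python
from torchrl.objectives import LossModule
from torchrl.objectives.utils import ValueEstimators


class TQCLoss(LossModule):
    ...

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        # TQC builds its truncated-quantile target inside the loss itself, so there
        # is no separate value estimator to construct; we only validate the request.
        if value_type not in (None, ValueEstimators.TD0):
            raise NotImplementedError(
                f"{type(self).__name__} computes its own quantile target and does "
                f"not support the {value_type} value estimator."
            )
```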

@maxweissenbacher (Contributor Author):

No worries! Ok - I'll see to doing that on Monday 😊

Do you think it would make sense to allow the value estimators to accept critic outputs of a different shape (as is needed for TQC)? What's the reason for imposing the restriction on the critic shape?

@maxweissenbacher (Contributor Author):

Done. Would you like me to move the TQCLoss implementation into the objectives directory?

Contributor:

I will draft the move for you and maybe leave it to you to complete the missing bits?

I'll have a look at how the critic works in this case for the shape restriction.
The reason we have that restriction is quite simply that not enforcing it can lead to issues :)
The most common is broadcasting: say you pass a reward of shape [T, 1] and a done state of shape [T]; value estimation without the check will broadcast the shapes and return a value of shape [T, T].
Because you can always expand a tensor with zero memory cost (as we do in MARL), we prefer to enforce that all shapes match ahead of time.
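A quick standalone illustration of that failure mode (toy shapes, not torchrl code):

```python
import torch

T = 4
reward = torch.randn(T, 1)      # shape [T, 1]
done = torch.zeros(T)           # shape [T] -- missing the trailing singleton dim
next_value = torch.randn(T)     # shape [T]

# Without a shape check, broadcasting silently yields a [T, T] "value".
target = reward + 0.99 * (1 - done) * next_value
print(target.shape)  # torch.Size([4, 4]) rather than the intended [4, 1]
```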

@maxweissenbacher (Contributor Author):

Sounds good!

That makes sense, thanks for explaining! Perhaps one could add an option to explicitly specify the shapes of the reward, value and done states when constructing the estimator, and then run some consistency checks on those shapes?

examples/tqc/utils.py (5 outdated review threads, resolved)
@vmoens added the "new algo" (New algorithm request or PR) label on Oct 18, 2023
@vmoens changed the title from "Added TQC" to "[Algorithm] Added TQC" on Oct 18, 2023
@maxweissenbacher (Contributor Author):

Thank you Vincent! Will get to work on those changes now. I think adding the TQCLoss to the objectives would be a great idea. If you could assist in writing tests, that would be great.

@vmoens (Contributor) commented Oct 25, 2023:

Just to keep you in the loop:
I'm working on a separate branch that integrates the loss in objectives, I'll keep you posted and probably ask for your help :)

@maxweissenbacher (Contributor Author):

Amazing! Definitely do let me know if I can help. Would love to hear your opinion on the issue with the value estimators (see the comment and the latest commit).

@vmoens (Contributor) commented Nov 9, 2023:

@maxweissenbacher I made a PR against your main branch with proposed changes
maxweissenbacher#1
(nit: in the future, it's best to branch out before making a PR; working from your main branch can cause some hassle for other developers working with you :) )

pytorch-bot (bot) commented Nov 10, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1631

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (18 Unrelated Failures)

As of commit 2e56b5b with merge base c7d4764:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vmoens (Contributor) commented Nov 27, 2023:

Hi @maxweissenbacher
Would you like us to take it from there?
I can find some time to make this mergeable; I just wanted to check with you first whether you still want to work on it?

@maxweissenbacher (Contributor Author):

Hi Vincent, my apologies for leaving this dormant for so long, I have been busy with moving house. I will get to work on it this week!

from torchrl.objectives.utils import ValueEstimators


class TQCLoss(LossModule):
Contributor:

The docstring should be added. It should include:

  • a link to the arxiv of the paper or similar
  • a short description of what this algo is about
  • an example of how to create and run the module
  • some notes regarding value estimation etc
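As one possible starting point, here is a sketch of such a docstring based on the constructor shown in this PR (the wording, paper reference and argument descriptions are illustrative, not final text):

```python
from torchrl.objectives import LossModule


class TQCLoss(LossModule):
    """TQC loss, following "Controlling Overestimation Bias with Truncated Mixture
    of Continuous Distributional Quantile Critics" (Kuznetsov et al., 2020),
    https://arxiv.org/abs/2005.04269.

    TQC trains an ensemble of distributional (quantile) critics in a SAC-style
    actor-critic loop; the target pools the critics' quantile atoms, sorts them and
    drops the largest ones, which controls overestimation bias.

    Args:
        actor_network (TensorDictModule): the policy to be trained.
        qvalue_network (TensorDictModule): a critic outputting quantile estimates
            of shape ``[*batch, n_nets, n_quantiles]``.

    Keyword Args:
        top_quantiles_to_drop (int, optional): number of pooled quantiles dropped
            when computing the target. Defaults to 10.

    .. note::
        The truncated-quantile target is computed inside the loss itself, so the
        usual value estimators (e.g. ``TD0Estimator``) are not used here.
    """

    ...
```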

Comment on lines +67 to +71
# no need to pass device, should be handled by actor/qvalue nets
# device: torch.device,
# gamma should be passed to the value estimator construction
# for consistency with other losses
# gamma: float=None,
Contributor:

these notes can be removed

self,
actor_network: TensorDictModule,
qvalue_network: TensorDictModule,
top_quantiles_to_drop: float = 10,
Contributor:

Suggested change:
- top_quantiles_to_drop: float = 10,
+ *,
+ top_quantiles_to_drop: float = 10,

create_target_params=True, # Create a target critic network
)

# self.device = device
Contributor:

Suggested change:
- # self.device = device

Comment on lines +151 to +152
self.actor(td_next, params=self.actor_params)
self.critic(td_next, params=self.target_critic_params)
Contributor:

Suggested change:
- self.actor(td_next, params=self.actor_params)
- self.critic(td_next, params=self.target_critic_params)
+ with self.actor_params.to_module(self.actor):
+     self.actor(td_next)
+ with self.target_critic_params.to_module(self.critic):
+     self.critic(td_next)
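For readers unfamiliar with the pattern these suggestions use: as far as I understand, `params.to_module(module)` used as a context manager temporarily loads the parameters stored in a TensorDict into the module, replacing the older functional `module(td, params=...)` call style. A small self-contained sketch of my reading of that behaviour (treat it as an assumption about the tensordict API, not documentation):

```python
import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 4)
params = TensorDict.from_module(module)     # snapshot of the module's parameters
zeroed = params.apply(lambda p: p * 0.0)    # an alternative parameter set

x = torch.randn(1, 3)
with zeroed.to_module(module):              # temporarily swap the parameters in
    print(module(x))                        # computed with the zeroed parameters
print(module(x))                            # original parameters are restored here
```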


# compute and cut quantiles at the next state
next_z = td_next.get(self.tensor_keys.state_action_value)
sorted_z, _ = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1))
Contributor:

Suggested change:
- sorted_z, _ = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1))
+ sorted_z = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1)).values

sorted_z_part - alpha * next_log_pi
)

self.critic(tensordict_copy, params=self.critic_params)
Contributor:

Suggested change:
- self.critic(tensordict_copy, params=self.critic_params)
+ with self.critic_params.to_module(self.critic):
+     self.critic(tensordict_copy)

Comment on lines +179 to +180
self.actor(tensordict_copy, params=self.actor_params)
self.critic(tensordict_copy, params=self.critic_params)
Contributor:

Suggested change:
- self.actor(tensordict_copy, params=self.actor_params)
- self.critic(tensordict_copy, params=self.critic_params)
+ with self.actor_params.to_module(self.actor):
+     self.actor(tensordict_copy)
+ with self.critic_params.to_module(self.critic):
+     self.critic(tensordict_copy)

Comment on lines +198 to +203
with set_exploration_type(ExplorationType.RANDOM):
    dist = self.actor.get_dist(
        tensordict,
        params=self.actor_params,
    )
    a_reparm = dist.rsample()
Contributor:

Suggested change:
- with set_exploration_type(ExplorationType.RANDOM):
-     dist = self.actor.get_dist(
-         tensordict,
-         params=self.actor_params,
-     )
-     a_reparm = dist.rsample()
+ with set_exploration_type(ExplorationType.RANDOM), self.actor_params.to_module(self.actor):
+     dist = self.actor.get_dist(
+         tensordict
+     )
+     a_reparm = dist.rsample()

@vmoens (Contributor) commented Nov 27, 2023:

> Hi Vincent, my apologies for leaving this dormant for so long, I have been busy with moving house. I will get to work on it this week!

No worries at all! I was just checking in to see if there was anything blocking you, really :)

Labels: CLA Signed · new algo
Projects: None yet
3 participants