Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Added linear interpolation to FrameSkipTransform #1692

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

Zekrom-7780
Copy link

Description

  • Added an action_interp parameter to enable or disable the action interpolation feature.
  • Added an action_interp_buffer attribute to store the previous action for interpolation.

Motivation and Context

Solves Issue #1659

Types of changes

  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the documentation accordingly.

Reviewers

cc. @vmoens

P.S. I don't know if I have to add unit-tests for this, OR if this is a breaking change, so I haven't added them here.

Copy link

pytorch-bot bot commented Nov 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1692

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 8 Pending, 3 Unrelated Failures

As of commit ce0a846 with merge base 11562c7 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 10, 2023
@vmoens vmoens added the enhancement New feature or request label Nov 10, 2023
@vmoens vmoens changed the title [Feature Request] Added linear interpolation to FrameSkipTransform [Feature] Added linear interpolation to FrameSkipTransform Nov 10, 2023
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for this!
We need to add a couple of tests for this feature.
What should happen at reset time, if there is more than one env being run?

def __init__(self, frame_skip: int = 1):
super().__init__()
def __init__(self, frame_skip: int = 1, action_interp: bool = False):
super().__init()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Author

@Zekrom-7780 Zekrom-7780 Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, that is a mistake from my side, sorry!
(regarding the init)

@@ -3893,28 +3893,59 @@ class FrameSkipTransform(Transform):
Args:
frame_skip (int, optional): a positive integer representing the number
of frames during which the same action must be applied.
action_interp (bool, optional): whether to perform action interpolation over frame_skip steps.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
action_interp (bool, optional): whether to perform action interpolation over frame_skip steps.
action_interp (bool, optional): whether to perform action interpolation over frame_skip steps.
Defaults to ``False``.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add an example that in the docstring

Suggested change
action_interp (bool, optional): whether to perform action interpolation over frame_skip steps.
Suggested change
action_interp (bool, optional): whether to perform action interpolation over frame_skip steps.
Defaults to ``False``.
Examples:
>>> from torchrl.envs import GymEnv
>>> env = TransformedEnv(GymEnv("CartPole-v1"), FrameSkipTransform(3, action_interp=False))
>>> print(env.rollout)
# add print here
>>> env = TransformedEnv(GymEnv("CartPole-v1"), FrameSkipTransform(3, action_interp=True))
>>> print(env.rollout)
# add print here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suree, will add it in the next commit


def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
parent = self.parent
if parent is None:
raise RuntimeError("parent not found for FrameSkipTransform")

if self.action_interp:
action_key = "_action"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a parameter in the class constructor

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it

)
self.action_interp_buffer = next_action
else:
interpolated_actions = [current_action] * (self.frame_skip - 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to store this contiguously (with a torch.stack around it)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what you mean by this, could you like elaborate?

def _linear_interpolation(
self, start_action: Tensor, end_action: Tensor, num_steps: int
) -> List[Tensor]:
if num_steps <= 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can num_steps be smaller than 0?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They cannot be, but still, I felt the need to add them, just in case the people using them don't run into any issues (like trying to enter a negative number just for fun and all)

Comment on lines +3943 to +3946
for step in range(1, num_steps):
alpha = step / num_steps
interpolated_action = start_action + alpha * (end_action - start_action)
interpolation_steps.append(interpolated_action)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I see what this block is doing, can you give an example?
It'd be cool to see (eg in the tests) what you would expect to get from it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so my reasoning here was for let's say num_steps = 5, we have:

At step = 1, alpha = 1/5 and interpolated_action is calculated as follows:
interpolated_action=start_action+(1/5)*(end_action - start_action)

Similarly for each step (like for the 4th step, alpha=4/5 ) and interpolated_action is calculated as follows:
interpolated_action=start_action+(4/5)*(end_action - start_action)

And finally, at step = 5, alpha = 5/5 (which is 1), and interpolated_action becomes :
interpolated_action=end_action

So like theinterpolation_steps list contains the sequence of actions that go from from start_action to end_action over the specified number of steps, which in this example is 5.
I felt that this made sense for moving from start_action to end_action , for smooth action transitions over multiple steps

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if this makes sense, but I felt that this is how i have to do it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok got it, do you mind if I try to do a vectorized version of this once you have written tests? I think we can make it faster

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suree! please do it!
I'll send another round of commits by the next 3ish hours
Also please could you answer what

would it make sense to store this contiguously (with a torch.stack around it)?

This means?

@Zekrom-7780
Copy link
Author

@vmoens, added the changes that you told me to.

Ready for another round of comments

would it make sense to store this contiguously (with a torch.stack around it)?

I wasn't sure what to do here, so I just wrapped interpolated_actions around a torch.stack and returned it.

Added action_key = "_action" as a parameter in the constructor, and now, will have to change the docstrings as well
Is the action correct, or you wanted a different word, or a different way of doing it?

Thanks so much for this!
We need to add a couple of tests for this feature.
What should happen at reset time, if there is more than one env being run?

Here, Should I create a action_interp_buffer (let's call it that) for each environment, and like reset when that environment is called?
But then how do I get the envirnoment like ID so as to know which one to reset?

@vmoens
Copy link
Contributor

vmoens commented Nov 15, 2023

@vmoens, added the changes that you told me to.

Ready for another round of comments

would it make sense to store this contiguously (with a torch.stack around it)?

I wasn't sure what to do here, so I just wrapped interpolated_actions around a torch.stack and returned it.

Added action_key = "_action" as a parameter in the constructor, and now, will have to change the docstrings as well Is the action correct, or you wanted a different word, or a different way of doing it?

Thanks so much for this!
We need to add a couple of tests for this feature.
What should happen at reset time, if there is more than one env being run?

Here, Should I create a action_interp_buffer (let's call it that) for each environment, and like reset when that environment is called? But then how do I get the envirnoment like ID so as to know which one to reset?

I'd make a step back and define precisely what we expect this to do.
What I understand we want from that feature is to have a stack of interpolated actions, for instance:

  • if action at t=0 is act = [1] and action at t=1 with frame_skip=4 is act=[5] we want the interpolated action to be [[1], [2], [3], [4], [5]], is that right?
  • Maybe people will want to have the interpolated action and the action, so we should be able to pass the interpolated action as a separate key. Something like action = [5] and interpolated_action = [[1], [2], [3], [4], [5]]? I think overwriting the "action" entry by default is dangerous because the action you will see out of the transformed environment will be incompatible with the base environment, so the default should be a separate entry for the interpolated action. If people want, they can use the "action" key name in the constructor but that will be subject to bugs (from which the user is responsible since this is not the default).
  • The way we do the interpolation (in a buffer or in the tensordict) seems accessory for now, I would first code the feature and all the test, then we can change the inner mechanism (I'll be happy to help do not worry).
  • Do we allow any action type? Eg, categorical or one-hot? If not we should check that the action_spec of the parent env (transform.parent.full_action_spec[transform.action_key]) has a compatible type (ie, if of dtype float or smth like that).
  • To me the top-1 priority here is to write a series of tests. To do that, go in test_transforms.py and add one or more tests in TestFrameSkipTransform. Check that the resulting action has the right values by inputting two actions at two consecutive steps and check that the values stricly match what you're expecting. That way we will know that any change we make works within the bounds of what we want to do. Also test what happens at reset time. Let's work with a single environment for now (no batched envs) and we'll adapt that to batched envs later.

@Zekrom-7780
Copy link
Author

@vmoens , I read your comments, and you mention that my first priority should be adding tests

So I looked at the test_transforms.py , and it has like 9000ish lines of code!!
Where should I like write my tests?
This file made me feel extremely overwhelmed!!

@vmoens
Copy link
Contributor

vmoens commented Dec 1, 2023

Ahah I can imagine!
There's a TestFrameSkipTransform with all the frame-skip related stuff. You just need a couple of functions that test that if you create a frame stkip transform the results are what is expected

@Zekrom-7780
Copy link
Author

Thanks a lot, will push some tests later, most probably in 24ish hours

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants