
[Feature Request] functiondict/opdict: TensorDict of Callables for apply() #937

Closed
ludwigwinkler opened this issue Jul 31, 2024 · 4 comments · Fixed by #939
Labels: enhancement (New feature or request)

@ludwigwinkler

Motivation

I would like to define a TensorDict of callables, which I can then apply elementwise as functions to a TensorDict of Tensors.
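The desired semantics can be sketched with plain Python dicts standing in for TensorDicts (`apply_fns` is an illustrative helper written for this sketch, not tensordict API):

```python
import operator

def apply_fns(fns, x, y):
    """Recursively apply a nested dict of callables to two nested dicts of values."""
    return {
        key: apply_fns(fn, x[key], y[key]) if isinstance(fn, dict) else fn(x[key], y[key])
        for key, fn in fns.items()
    }

x = {"a": 1.0, "z": {"b": 2.0}}
y = {"a": 3.0, "z": {"b": 4.0}}
fns = {"a": operator.add, "z": {"b": operator.mul}}

out = apply_fns(fns, x, y)
# out == {"a": 4.0, "z": {"b": 8.0}}
```

The point is that each leaf of the structure carries its own combining function, mirroring the keys of the data.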

Solution

Until recently, in tensordict==0.4, I could get this behavior with the hacky _run_checks=False, which prevented TensorDict from wrapping the callables in NonTensorData.

I create two tensordicts x and y of data and one tensordict of functions.
Then I use apply() to combine x and y according to the tensordict of functions:

import torch

from tensordict import TensorDict

x = TensorDict(
    {
        "x": torch.tensor([1.0, 0.1, 0.0]),
        "y": torch.tensor([-1.0, 0.1, 0.0]),
        "z": {"a": torch.tensor([1.0, 0.1, 0.0]), "b": torch.tensor([1.0, 0.1, 0.0])},
    },
    batch_size=None,
)

y = TensorDict(
    {
        "x": torch.tensor([-1.0, 0.1, 0.0]),
        "y": torch.tensor([2.0, 0.1, -0.1]),
        "z": {"a": torch.tensor([1.0, 0.1, 0.0]), "b": torch.tensor([1.0, 0.1, 0.0])},
    },
)

compose_fn = TensorDict(
    {"x": lambda x, y: x + y, "y": torch.mul, "z": {"a": torch.add, "b": torch.mul}},
    _run_checks=False,
)

out = x.apply(lambda x_, fn, y_: fn(x_, y_), compose_fn, y)

Now, in tensordict==0.5, this is not possible anymore.

Is there perhaps a more elegant and abstracted way of achieving such tensordicts of functions?

Alternatives

I know that one can access NonTensorData by using tolist(), which returns the desired function if there is only a single element per node, but it's not the most elegant solution.

@ludwigwinkler ludwigwinkler added the enhancement New feature or request label Jul 31, 2024
@vmoens
Copy link
Contributor

vmoens commented Jul 31, 2024

Interesting, never thought of that!

Off the top of my head, here are a couple of things I would suggest:

  1. Use named_apply:

from collections import defaultdict

dict_func = defaultdict(lambda: lambda x, y: x + y)
dict_func["a"] = lambda x, y: x - y
td0 = TensorDict(a=1, b=1, c=1)
td1 = TensorDict(a=2, b=2, c=2)
td = td0.named_apply(lambda name, x, y: dict_func[name](x, y), td1)
assert td["a"] == -1
assert td["b"] == 3

  2. Use _new_unsafe, which is the new way of creating a TD without checks. I would obviously recommend against this one lol.

  3. I can run your code like this:

compose_fn = TensorDict(
    {"x": lambda x, y: x + y, "y": torch.mul, "z": {"a": torch.add, "b": torch.mul}},
    # _run_checks=False,
)
out = x.apply(lambda x_, fn, y_: fn.data(x_, y_), compose_fn, y)

Hope that helps!


ludwigwinkler commented Aug 1, 2024

Thank you for the suggestions.

Would you have a suggestion for how I could push the fn.data call into the iteration over the nodes of the FunctionDict?

I would imagine something like inheriting from TensorDict (class FunctionDict(TensorDict)) and overriding the method that retrieves a particular node (which is a NonTensorData) from the FunctionDict, adding one extra .data unpacking step.

I looked through the v0.5 code and found this line:

_others = [_other._get_str(key, default=None) for _other in others]

from which I gathered that I might have to override the _get_str() method. But I couldn't make it work, unfortunately. :( Would you have a suggestion?

class OpDict(TensorDict):
    def _get_str(self, key, default):
        # Fetch the node, then unwrap the NonTensorData payload via .data
        out = self._tensordict.get(key, None)
        if out is None:
            return self._default_get(key, default).data
        return out.data
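The idea behind that override, unwrapping a wrapper object at retrieval time, can be illustrated in plain Python (Wrapped and UnwrapDict are hypothetical stand-ins for this sketch, not tensordict internals):

```python
class Wrapped:
    """Stand-in for NonTensorData: holds an arbitrary payload in .data."""
    def __init__(self, data):
        self.data = data

class UnwrapDict(dict):
    """A dict that transparently unwraps Wrapped payloads on access,
    analogous to an OpDict whose _get_str returns .data."""
    def __getitem__(self, key):
        value = super().__getitem__(key)
        return value.data if isinstance(value, Wrapped) else value

ops = UnwrapDict({"add": Wrapped(lambda a, b: a + b)})
assert ops["add"](2, 3) == 5  # the callable comes back already unwrapped
```

Doing the same inside TensorDict is harder because apply() retrieves nodes through internal code paths, which is why overriding only _get_str may not be enough.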


vmoens commented Aug 1, 2024

To me the simplest thing could be to add this to NonTensorData:

@tensorclass
class NonTensorData:
    def __call__(self, *args, **kwargs):
        return self.data(*args, **kwargs)

We could make that part of the lib; I don't think it'd be a problem.
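The effect of that patch is plain call delegation, which a minimal stand-alone sketch makes concrete (NonTensorBox is an illustrative class written for this example, not the tensordict implementation):

```python
class NonTensorBox:
    """Minimal stand-in for NonTensorData: wraps a payload in .data."""
    def __init__(self, data):
        self.data = data

    def __call__(self, *args, **kwargs):
        # Forward the call to the wrapped callable, so fn(x, y) works
        # without an explicit fn.data(x, y) unwrapping step.
        return self.data(*args, **kwargs)

fn = NonTensorBox(lambda a, b: a - b)
assert fn(5, 3) == 2
```

With such a `__call__` on the wrapper, the original `x.apply(lambda x_, fn, y_: fn(x_, y_), compose_fn, y)` spelling works without touching .data at the call site.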

@ludwigwinkler
Author

I patched it in on my local machine and it works very nicely. That'd be an ideal solution. :) Please make it part of the lib!
