[Feature Request] functiondict/opdict: TensorDict of Callables for apply() #937
Comments
Interesting, I never thought of that. Off the top of my head, here are a couple of things I would suggest:
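from collections import defaultdict
from tensordict import TensorDict

# Every key defaults to addition; "a" is overridden with subtraction
dict_func = defaultdict(lambda: lambda x, y: x + y)
dict_func["a"] = lambda x, y: x - y
td0 = TensorDict(a=1, b=1, c=1)
td1 = TensorDict(a=2, b=2, c=2)
# named_apply passes each leaf's name first, so the name selects the operation
td = td0.named_apply(lambda name, x, y: dict_func[name](x, y), td1)
assert td["a"] == -1
assert td["b"] == 3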
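The first suggestion dispatches on the leaf name, with a defaultdict supplying a fallback operation. The same pattern can be sketched in plain Python over nested dicts, without tensordict. named_apply_dicts is a hypothetical helper written only for illustration (it is not a tensordict API); it assumes both inputs share the same nested keys and, like named_apply with its default nested_keys=False, passes only the last key of each leaf to the function.

```python
from collections import defaultdict
from typing import Any, Callable, Dict


# Hypothetical helper (not part of tensordict): walks two nested dicts with
# identical structure and calls fn(leaf_name, a, b) on each pair of leaves.
def named_apply_dicts(
    fn: Callable[[str, Any, Any], Any], d0: Dict[str, Any], d1: Dict[str, Any]
) -> Dict[str, Any]:
    out = {}
    for name, a in d0.items():
        b = d1[name]
        if isinstance(a, dict):
            # Recurse into sub-dicts; only the leaf key is passed to fn.
            out[name] = named_apply_dicts(fn, a, b)
        else:
            out[name] = fn(name, a, b)
    return out


# Every key defaults to addition; "a" is overridden with subtraction.
dict_func = defaultdict(lambda: lambda x, y: x + y)
dict_func["a"] = lambda x, y: x - y

d0 = {"a": 1, "b": 1, "c": 1}
d1 = {"a": 2, "b": 2, "c": 2}
result = named_apply_dicts(lambda name, x, y: dict_func[name](x, y), d0, d1)
# result == {"a": -1, "b": 3, "c": 3}
```

Because dispatch is by leaf name only, a nested key such as ("z", "a") would also pick up the subtraction registered for "a"; keying on full paths instead would avoid that collision.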
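import torch
from tensordict import TensorDict

compose_fn = TensorDict(
    {"x": lambda x, y: x + y, "y": torch.mul, "z": {"a": torch.add, "b": torch.mul}},
    # _run_checks=False,
)
# x and y are the data TensorDicts (same keys as compose_fn); each callable is
# stored as NonTensorData, so the function itself is reached through fn.data
out = x.apply(lambda x_, fn, y_: fn.data(x_, y_), compose_fn, y)

Hope that helps!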
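The compose_fn approach stores the operations themselves in a nested structure that mirrors the data and looks up the matching function for each leaf. A plain-Python sketch of that idea follows; apply_fn_tree is a hypothetical name, and operator.add / operator.mul stand in for torch.add / torch.mul so that the example runs without PyTorch or tensordict.

```python
import operator
from typing import Any, Dict


# Hypothetical helper: x, fns and y must share the same nested keys.
# The argument order mirrors x.apply(fn, compose_fn, y).
def apply_fn_tree(x: Dict[str, Any], fns: Dict[str, Any], y: Dict[str, Any]) -> Dict[str, Any]:
    out = {}
    for key, x_leaf in x.items():
        fn, y_leaf = fns[key], y[key]
        if isinstance(x_leaf, dict):
            # Sub-dict: fn is itself a nested dict of functions.
            out[key] = apply_fn_tree(x_leaf, fn, y_leaf)
        else:
            # Leaf: fns holds a plain callable, so no unwrapping is needed.
            out[key] = fn(x_leaf, y_leaf)
    return out


compose_fn = {
    "x": lambda a, b: a + b,
    "y": operator.mul,
    "z": {"a": operator.add, "b": operator.mul},
}
x = {"x": 1, "y": 2, "z": {"a": 3, "b": 4}}
y = {"x": 10, "y": 20, "z": {"a": 30, "b": 40}}
out = apply_fn_tree(x, compose_fn, y)
# out == {"x": 11, "y": 40, "z": {"a": 33, "b": 160}}
```

In tensordict 0.5 the leaves of compose_fn are NonTensorData wrappers rather than bare callables, which is the only reason the real version needs fn.data instead of fn.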
Thank you for the suggestions. Would you have a suggestion for how I could push the … I would imagine something like inheriting from … I looked through the v0.5 code and found this line (line 1314 at commit 37feb13) in the _get_str() function, but I couldn't make it work, unfortunately. :( Would you have a suggestion?
To me, the simplest thing would be to add this:

@tensorclass
class NonTensorData:
    def __call__(self, *args, **kwargs):
        return self.data(*args, **kwargs)

We could make that part of the lib; I don't think it'd be a problem.
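With __call__ forwarding to the stored object, the wrapper itself becomes callable, so an apply() lambda could call fn(x_, y_) directly instead of fn.data(x_, y_). A minimal stand-alone sketch of that forwarding pattern, using a plain dataclass named CallableBox (hypothetical, standing in for NonTensorData; it is not the tensordict class):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class CallableBox:
    """Hypothetical stand-in for a wrapper that stores an arbitrary object in .data."""

    data: Any

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward positional and keyword arguments to the wrapped object,
        # mirroring the proposed NonTensorData.__call__.
        return self.data(*args, **kwargs)


sub = CallableBox(lambda x, y: x - y)
direct = sub(5, 3)         # call the wrapper itself
via_data = sub.data(5, 3)  # the current workaround: unwrap first
```

One consequence of this design: wrapping a non-callable payload still succeeds, and the TypeError only surfaces when the wrapper is called, exactly as it would for the unwrapped object.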
I patched it in on my local machine and it works very nicely. That'd be an ideal solution. :) Please make it part of the lib!
Motivation
I would like to define a TensorDict of callables, which I can then apply elementwise as functions to a TensorDict of Tensors.
Solution
Until recently, in tensordict==0.4, I could create this behavior with the hacky _run_checks=False flag, which prevented TensorDict from wrapping the callables in NonTensorData. I create two tensordicts of data, x and y, and one tensordict of functions. Then I use apply() to combine x and y according to the tensordict of functions.
Now, in tensordict==0.5, this is not possible anymore.
Is there perhaps a more elegant and abstracted way of achieving such tensordicts of functions?
Alternatives
I know that one can access the contents of NonTensorData by using tolist(), which returns the desired function if there is only a single element per node, but it's not the most elegant solution.