ARROW-2917: [Python] Use detach() to avoid PyTorch gradient errors
`detach()` doesn't copy data unless it has to, and PyTorch raises a RuntimeError if the detached data later needs to have its gradient calculated.

Author: Wes McKinney <wesm+git@apache.org>
Author: Alok Singh <8325708+alok@users.noreply.github.com>

Closes #2311 from alok/patch-1 and squashes the following commits:

e451de8 <Wes McKinney> Add unit test serializing pytorch tensor requiring gradient that fails on master
f8e298f <Alok Singh> Use detach() to avoid torch gradient errors
wesm committed Jul 27, 2018
1 parent 47e462f commit fdc8e6a
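
To unpack the claim in the commit message a bit, here is a minimal sketch (not part of the commit): `detach()` returns a view on the same storage, and PyTorch's version counter catches in-place writes through that view if autograd still needs the original values.

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x.exp()            # exp()'s backward needs the saved output y
d = y.detach()

# detach() did not copy anything: both tensors point at the same storage.
assert d.data_ptr() == y.data_ptr()

# Writing through the detached view and then backpropagating raises a
# RuntimeError ("... has been modified by an inplace operation ...").
d.zero_()
try:
    y.sum().backward()
except RuntimeError as exc:
    print(exc)
```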
Showing 2 changed files with 5 additions and 1 deletion.
python/pyarrow/serialization.py (1 addition, 1 deletion)

@@ -136,7 +136,7 @@ def register_torch_serialization_handlers(serialization_context):
     import torch
 
     def _serialize_torch_tensor(obj):
-        return obj.numpy()
+        return obj.detach().numpy()
 
     def _deserialize_torch_tensor(data):
         return torch.from_numpy(data)
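
For reference, a minimal sketch (not from the commit) of the failure the one-line change above avoids; the exact error text varies across PyTorch versions:

```python
import torch

t = torch.randn(3, requires_grad=True)

# Direct conversion fails because the tensor is tracked by autograd:
#   t.numpy()
#   RuntimeError: Can't call numpy() on Tensor that requires grad.
#                 Use tensor.detach().numpy() instead.

# Detaching first succeeds and shares the underlying memory, so no copy is made.
arr = t.detach().numpy()
print(arr.shape)  # (3,)
```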
python/pyarrow/tests/test_serialization.py (4 additions, 0 deletions)

@@ -364,6 +364,10 @@ def test_torch_serialization(large_buffer):
         serialization_roundtrip(obj, large_buffer,
                                 context=serialization_context)
 
+    tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)
+    serialization_roundtrip(tensor_requiring_grad, large_buffer,
+                            context=serialization_context)
+
 
 def test_numpy_immutable(large_buffer):
     obj = np.zeros([10])
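
Outside the test suite, a round trip with this fix might look like the sketch below. It uses pyarrow's legacy serialize/deserialize API; the explicit handler registration is shown for clarity and may be redundant if the default context already registers the torch handlers.

```python
import pyarrow as pa
import torch
from pyarrow.serialization import register_torch_serialization_handlers

context = pa.default_serialization_context()
register_torch_serialization_handlers(context)

# A tensor that participates in autograd; obj.numpy() alone would fail on it.
tensor = torch.randn(10, 10, requires_grad=True)

buf = pa.serialize(tensor, context=context).to_buffer()
restored = pa.deserialize(buf, context=context)

# The round trip preserves the values but not the autograd history.
assert torch.equal(tensor.detach(), restored)
assert not restored.requires_grad
```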
