Skip to content

Commit

Permalink
make ProvenanceTensor behave more like a Tensor (closes #3218) (#3220)
Browse files Browse the repository at this point in the history
* make ProvenanceTensor behave more like a Tensor (closes #3218)

the data is now stored in the Tensor itself instead of an attribute.
This fixes torch.to_tensor returning empty tensors when called with a
ProvenanceTensor and and a device as arguments

* fix compatibility with PyTorch 1.11

* make detach_provenance always return the exact same object

this is important when using Tensors as keys in a dict, e.g.
the Pyro param store

* preserve .unconstrained attribute in detach_provenance

* simplify

* add unit test

* simplify further

* use Tensor.as_subclass instead of modifying __class__

also remove unnecessary check in __init__
  • Loading branch information
ilia-kats authored May 28, 2023
1 parent 89a56ea commit 831c463
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs):
assert not isinstance(data, ProvenanceTensor)
if not provenance:
return data
return super().__new__(cls)
ret = data.as_subclass(cls)
ret._t = data # this makes sure that detach_provenance always
# returns the same object. This is important when
# using the tensor as key in a dict, e.g. the global
# param store
return ret

def __init__(self, data, provenance=frozenset()):
assert isinstance(provenance, frozenset)
if isinstance(data, ProvenanceTensor):
provenance |= data._provenance
data = data._t
self._t = data
self._provenance = provenance

def __repr__(self):
Expand Down
45 changes: 45 additions & 0 deletions tests/ops/test_provenance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pyro.ops.provenance import ProvenanceTensor
from tests.common import assert_equal, requires_cuda


@requires_cuda
@pytest.mark.parametrize(
"dtype1",
[
torch.float16,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
],
)
@pytest.mark.parametrize(
"dtype2",
[
torch.float16,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
],
)
def test_provenance_tensor(dtype1, dtype2):
device = torch.device("cuda")
x = torch.tensor([1, 2, 3], dtype=dtype1)
y = ProvenanceTensor(x, frozenset(["x"]))
z = torch.as_tensor(y, device=device, dtype=dtype2)

assert x.shape == y.shape == z.shape
assert_equal(x, z.cpu())

0 comments on commit 831c463

Please sign in to comment.