Skip to content

Commit

Permalink
Add `EdgeIndex` tests for `torch.compile` and `torch.jit.script`; add `is_torch_instance` (#8461)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 28, 2023
1 parent 2ab23f6 commit 6336cdf
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `is_torch_instance` to check against the original class of compiled models ([#8461](https://github.com/pyg-team/pytorch_geometric/pull/8461))
- Added dense computation for `AddRandomWalkPE` ([#8431](https://github.com/pyg-team/pytorch_geometric/pull/8431))
- Added a tutorial for point cloud processing ([#8015](https://github.com/pyg-team/pytorch_geometric/pull/8015))
- Added `fsspec` as file system backend ([#8379](https://github.com/pyg-team/pytorch_geometric/pull/8379), [#8426](https://github.com/pyg-team/pytorch_geometric/pull/8426), [#8434](https://github.com/pyg-team/pytorch_geometric/pull/8434))
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ filterwarnings = [
"ignore:Sparse CSR tensor support is in beta state:UserWarning",
"ignore:Sparse CSC tensor support is in beta state:UserWarning",
"ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning",
# Filter `torch.compile` warnings:
"ignore:pkg_resources is deprecated as an API",
"ignore:Deprecated call to `pkg_resources.declare_namespace",
# Filter `captum` warnings:
"ignore:Setting backward hooks on ReLU activations:UserWarning",
"ignore:.*did not already require gradients, required_grads has been set automatically:UserWarning",
Expand Down
80 changes: 79 additions & 1 deletion test/data/test_edge_index.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import os.path as osp
from typing import Optional

import pytest
import torch
from torch import Tensor

import torch_geometric
from torch_geometric.data.edge_index import EdgeIndex
from torch_geometric.testing import onlyCUDA, withCUDA
from torch_geometric.testing import (
disableExtensions,
onlyCUDA,
onlyLinux,
withCUDA,
withPackage,
)
from torch_geometric.utils import scatter


def test_basic():
Expand Down Expand Up @@ -239,3 +249,71 @@ def test_data_loader(num_workers):
assert isinstance(adj, EdgeIndex)
assert adj.is_shared() == (num_workers > 0)
assert adj._rowptr.is_shared() == (num_workers > 0)


def test_torch_script():
    # Reference model that relies on the `EdgeIndex.num_cols` metadata to
    # determine the output size of the scatter aggregation:
    class Model(torch.nn.Module):
        def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:
            src, dst = edge_index[0], edge_index[1]
            return scatter(x[src], dst, dim_size=edge_index.num_cols)

    x = torch.randn(3, 8)
    # Leave the last node isolated so that `num_cols` actually matters:
    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3))

    model = Model()
    expected = model(x, edge_index)
    assert expected.size() == (3, 8)

    # `torch.jit.script` does not support inheritance at the `Tensor` level :(
    with pytest.raises(RuntimeError, match="attribute or method 'num_cols'"):
        torch.jit.script(model)

    # A valid workaround is to fall back to treating `EdgeIndex` as a plain
    # PyTorch tensor whenever we are inside TorchScript:
    class ScriptableModel(torch.nn.Module):
        def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:
            src, dst = edge_index[0], edge_index[1]
            dim_size: Optional[int] = None
            if (not torch.jit.is_scripting()
                    and isinstance(edge_index, EdgeIndex)):
                dim_size = edge_index.num_cols
            return scatter(x[src], dst, dim_size=dim_size)

    script_model = torch.jit.script(ScriptableModel())
    out = script_model(x, edge_index)
    # Without `num_cols`, the output shrinks to `dst.max() + 1` rows:
    assert out.size() == (2, 8)
    assert torch.allclose(out, expected[:2])


@onlyLinux
@disableExtensions
@withPackage('torch>=2.1.0')
def test_compile():
    import torch._dynamo as dynamo

    # Model that relies on `EdgeIndex.num_cols` to size its output:
    class Model(torch.nn.Module):
        def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:
            src, dst = edge_index[0], edge_index[1]
            return scatter(x[src], dst, dim_size=edge_index.num_cols)

    x = torch.randn(3, 8)
    # Leave the last node isolated so that `num_cols` actually matters:
    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3))

    model = Model()
    expected = model(x, edge_index)
    assert expected.size() == (3, 8)

    # Dynamo should trace the model without hitting any graph break:
    explanation = dynamo.explain(model)(x, edge_index)
    assert explanation.graph_break_count <= 0

    compiled_model = torch_geometric.compile(model)
    out = compiled_model(x, edge_index)
    assert torch.allclose(out, expected)
16 changes: 16 additions & 0 deletions test/test_isinstance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch

from torch_geometric import is_torch_instance
from torch_geometric.testing import onlyLinux, withPackage


def test_basic():
    # On a plain (uncompiled) module, behaves exactly like `isinstance`:
    model = torch.nn.Linear(1, 1)
    assert is_torch_instance(model, torch.nn.Linear)


@onlyLinux
@withPackage('torch>=2.0.0')
def test_compile():
    model = torch.compile(torch.nn.Linear(1, 1))
    # `torch.compile` wraps the module, so a plain `isinstance` check fails:
    assert not isinstance(model, torch.nn.Linear)
    # ...while `is_torch_instance` sees through the wrapper:
    assert is_torch_instance(model, torch.nn.Linear)
2 changes: 2 additions & 0 deletions torch_geometric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .seed import seed_everything
from .home import get_home_dir, set_home_dir
from .compile import compile
from .isinstance import is_torch_instance
from .debug import is_debug_enabled, debug, set_debug
from .experimental import (is_experimental_mode_enabled, experimental_mode,
set_experimental_mode)
Expand All @@ -26,6 +27,7 @@
'get_home_dir',
'set_home_dir',
'compile',
'is_torch_instance',
'is_debug_enabled',
'debug',
'set_debug',
Expand Down
25 changes: 25 additions & 0 deletions torch_geometric/isinstance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any, Tuple, Type, Union

import torch

import torch_geometric.typing

if torch_geometric.typing.WITH_PT20:
import torch._dynamo


def is_torch_instance(obj: Any, cls: Union[Type, Tuple[Type]]) -> bool:
    r"""Checks whether :obj:`obj` is an instance of :obj:`cls`.

    Extends :meth:`isinstance` to also work on models processed by
    :meth:`torch.compile` by checking against their original class.
    """
    # `torch.compile` wraps the model into a `torch._dynamo.OptimizedModule`,
    # which hides the original inheritance chain from `isinstance`.
    # Unwrap the compiled model and check against its source module instead:
    if torch_geometric.typing.WITH_PT20:
        if isinstance(obj, torch._dynamo.OptimizedModule):
            return isinstance(obj._orig_mod, cls)
    return isinstance(obj, cls)

0 comments on commit 6336cdf

Please sign in to comment.