diff --git a/CHANGELOG.md b/CHANGELOG.md index d457b7de9974..ad8f3484d5df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `is_torch_instance` to check against the original class of compiled models ([#8461](https://github.com/pyg-team/pytorch_geometric/pull/8461)) - Added dense computation for `AddRandomWalkPE` ([#8431](https://github.com/pyg-team/pytorch_geometric/pull/8431)) - Added a tutorial for point cloud processing ([#8015](https://github.com/pyg-team/pytorch_geometric/pull/8015)) - Added `fsspec` as file system backend ([#8379](https://github.com/pyg-team/pytorch_geometric/pull/8379), [#8426](https://github.com/pyg-team/pytorch_geometric/pull/8426), [#8434](https://github.com/pyg-team/pytorch_geometric/pull/8434)) diff --git a/pyproject.toml b/pyproject.toml index b252b5168925..16455501bc3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,6 +151,9 @@ filterwarnings = [ "ignore:Sparse CSR tensor support is in beta state:UserWarning", "ignore:Sparse CSC tensor support is in beta state:UserWarning", "ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning", + # Filter `torch.compile` warnings: + "ignore:pkg_resources is deprecated as an API", + "ignore:Deprecated call to `pkg_resources.declare_namespace", # Filter `captum` warnings: "ignore:Setting backward hooks on ReLU activations:UserWarning", "ignore:.*did not already require gradients, required_grads has been set automatically:UserWarning", diff --git a/test/data/test_edge_index.py b/test/data/test_edge_index.py index 518269851445..6e122e50c800 100644 --- a/test/data/test_edge_index.py +++ b/test/data/test_edge_index.py @@ -1,10 +1,20 @@ import os.path as osp +from typing import Optional import pytest import torch +from torch import Tensor +import torch_geometric from torch_geometric.data.edge_index import EdgeIndex -from torch_geometric.testing import onlyCUDA, 
withCUDA +from torch_geometric.testing import ( + disableExtensions, + onlyCUDA, + onlyLinux, + withCUDA, + withPackage, +) +from torch_geometric.utils import scatter def test_basic(): @@ -239,3 +249,71 @@ def test_data_loader(num_workers): assert isinstance(adj, EdgeIndex) assert adj.is_shared() == (num_workers > 0) assert adj._rowptr.is_shared() == (num_workers > 0) + + +def test_torch_script(): + class Model(torch.nn.Module): + def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: + row, col = edge_index[0], edge_index[1] + x_j = x[row] + out = scatter(x_j, col, dim_size=edge_index.num_cols) + return out + + x = torch.randn(3, 8) + # Test that `num_cols` gets picked up by making last node isolated. + edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3)) + + model = Model() + expected = model(x, edge_index) + assert expected.size() == (3, 8) + + # `torch.jit.script` does not support inheritance at the `Tensor` level :( + with pytest.raises(RuntimeError, match="attribute or method 'num_cols'"): + torch.jit.script(model) + + # A valid workaround is to treat `EdgeIndex` as a regular PyTorch tensor + # whenever we are in script mode: + class ScriptableModel(torch.nn.Module): + def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: + row, col = edge_index[0], edge_index[1] + x_j = x[row] + dim_size: Optional[int] = None + if (not torch.jit.is_scripting() + and isinstance(edge_index, EdgeIndex)): + dim_size = edge_index.num_cols + out = scatter(x_j, col, dim_size=dim_size) + return out + + script_model = torch.jit.script(ScriptableModel()) + out = script_model(x, edge_index) + assert out.size() == (2, 8) + assert torch.allclose(out, expected[:2]) + + +@onlyLinux +@disableExtensions +@withPackage('torch>=2.1.0') +def test_compile(): + import torch._dynamo as dynamo + + class Model(torch.nn.Module): + def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: + row, col = edge_index[0], edge_index[1] + x_j = x[row] + out = 
scatter(x_j, col, dim_size=edge_index.num_cols) + return out + + x = torch.randn(3, 8) + # Test that `num_cols` gets picked up by making last node isolated. + edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3)) + + model = Model() + expected = model(x, edge_index) + assert expected.size() == (3, 8) + + explanation = dynamo.explain(model)(x, edge_index) + assert explanation.graph_break_count <= 0 + + compiled_model = torch_geometric.compile(model) + out = compiled_model(x, edge_index) + assert torch.allclose(out, expected) diff --git a/test/test_isinstance.py b/test/test_isinstance.py new file mode 100644 index 000000000000..d2360cca780b --- /dev/null +++ b/test/test_isinstance.py @@ -0,0 +1,16 @@ +import torch + +from torch_geometric import is_torch_instance +from torch_geometric.testing import onlyLinux, withPackage + + +def test_basic(): + assert is_torch_instance(torch.nn.Linear(1, 1), torch.nn.Linear) + + +@onlyLinux +@withPackage('torch>=2.0.0') +def test_compile(): + model = torch.compile(torch.nn.Linear(1, 1)) + assert not isinstance(model, torch.nn.Linear) + assert is_torch_instance(model, torch.nn.Linear) diff --git a/torch_geometric/__init__.py b/torch_geometric/__init__.py index 7a9ed82ae1ea..cc3243c16acc 100644 --- a/torch_geometric/__init__.py +++ b/torch_geometric/__init__.py @@ -11,6 +11,7 @@ from .seed import seed_everything from .home import get_home_dir, set_home_dir from .compile import compile +from .isinstance import is_torch_instance from .debug import is_debug_enabled, debug, set_debug from .experimental import (is_experimental_mode_enabled, experimental_mode, set_experimental_mode) @@ -26,6 +27,7 @@ 'get_home_dir', 'set_home_dir', 'compile', + 'is_torch_instance', 'is_debug_enabled', 'debug', 'set_debug', diff --git a/torch_geometric/isinstance.py b/torch_geometric/isinstance.py new file mode 100644 index 000000000000..a64d74cd4ec8 --- /dev/null +++ b/torch_geometric/isinstance.py @@ -0,0 +1,25 @@ +from typing import 
Any, Tuple, Type, Union
+
+import torch
+
+import torch_geometric.typing
+
+if torch_geometric.typing.WITH_PT20:
+    import torch._dynamo
+
+
+def is_torch_instance(obj: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
+    r"""Checks if :obj:`obj` is an instance of :obj:`cls`.
+
+    This function extends :meth:`isinstance` to be applicable during
+    :meth:`torch.compile` usage by checking against the original class of
+    compiled models.
+    """
+    # `torch.compile` removes the model inheritance and converts the model to
+    # a `torch._dynamo.OptimizedModule` instance, leading to `isinstance` being
+    # unable to check the model's inheritance. This function unwraps the
+    # compiled model before evaluating via `isinstance`.
+    if (torch_geometric.typing.WITH_PT20
+            and isinstance(obj, torch._dynamo.OptimizedModule)):
+        return isinstance(obj._orig_mod, cls)
+    return isinstance(obj, cls)