From fa47205513b3806d40681e711ee4c0fd1c3149eb Mon Sep 17 00:00:00 2001 From: "padarn.wilson" Date: Sat, 18 Jun 2022 12:48:02 +0800 Subject: [PATCH] rebase to fix tests --- test/nn/aggr/test_equilibrium.py | 4 ++-- torch_geometric/nn/aggr/equilibrium.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/nn/aggr/test_equilibrium.py b/test/nn/aggr/test_equilibrium.py index 92f926020ef3..ceea61d6c303 100644 --- a/test/nn/aggr/test_equilibrium.py +++ b/test/nn/aggr/test_equilibrium.py @@ -19,7 +19,7 @@ def test_equilibrium(iter, alpha): out = model(x) assert out.size() == (1, 2) - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): model(x, dim_size=0) out = model(x, dim_size=3) @@ -45,7 +45,7 @@ def test_equilibrium_batch(iter, alpha): out = model(x, batch) assert out.size() == (2, 2) - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): model(x, dim_size=0) out = model(x, dim_size=3) diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py index f1b823ab5397..23358c4437ef 100644 --- a/torch_geometric/nn/aggr/equilibrium.py +++ b/torch_geometric/nn/aggr/equilibrium.py @@ -167,15 +167,14 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, dim: int = -2) -> Tensor: if ptr is not None: - raise NotImplementedError( - f"{self.__class__} doesn't support `ptr`") + raise ValueError(f"{self.__class__} doesn't support `ptr`") index_size = 1 if index is None else index.max() + 1 dim_size = index_size if dim_size is None else dim_size if dim_size < index_size: - raise NotImplementedError("`dim_size` is less than `index` " - "implied size") + raise ValueError("`dim_size` is less than `index` " + "implied size") with torch.enable_grad(): y = self.optimizer(x, self.init_output(index), index, self.energy,