Skip to content

Commit

Permalink
Add TorchScript support for Node2Vec (#6726)
Browse files Browse the repository at this point in the history
The original JIT test of the `Node2Vec` model was not actually using
TorchScript; this change fixes that.

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
ftxj and rusty1s authored Feb 17, 2023
1 parent c45f8db commit a0ffd6f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added TorchScript support to the `Node2Vec` model ([#6726](https://github.com/pyg-team/pytorch_geometric/pull/6726))
- Added `utils.to_edge_index` to convert sparse tensors to edge indices and edge attributes ([#6728](https://github.com/pyg-team/pytorch_geometric/issues/6728))
- Fixed expected data format in `PolBlogs` dataset ([#6714](https://github.com/pyg-team/pytorch_geometric/issues/6714))
- Added `SimpleConv` to perform non-trainable propagation ([#6718](https://github.com/pyg-team/pytorch_geometric/pull/6718))
Expand Down
2 changes: 1 addition & 1 deletion test/nn/models/test_node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_node2vec():
assert 0 <= acc and acc <= 1

if is_full_test():
jit = torch.jit.export(model)
jit = torch.jit.script(model)

assert jit(torch.arange(3)).size() == (3, 16)

Expand Down
12 changes: 7 additions & 5 deletions torch_geometric/nn/models/node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
except ImportError:
random_walk = None

EPS = 1e-15


class Node2Vec(torch.nn.Module):
r"""The Node2Vec model from the
Expand Down Expand Up @@ -71,7 +69,7 @@ def __init__(
row, col = edge_index
self.adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
self.adj = self.adj.to('cpu')

self.EPS = 1e-15
assert walk_length >= context_size

self.embedding_dim = embedding_dim
Expand Down Expand Up @@ -99,6 +97,7 @@ def loader(self, **kwargs) -> DataLoader:
return DataLoader(range(self.adj.sparse_size(0)),
collate_fn=self.sample, **kwargs)

@torch.jit.export
def pos_sample(self, batch: Tensor) -> Tensor:
batch = batch.repeat(self.walks_per_node)
rowptr, col, _ = self.adj.csr()
Expand All @@ -112,6 +111,7 @@ def pos_sample(self, batch: Tensor) -> Tensor:
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)

@torch.jit.export
def neg_sample(self, batch: Tensor) -> Tensor:
batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

Expand All @@ -125,11 +125,13 @@ def neg_sample(self, batch: Tensor) -> Tensor:
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)

@torch.jit.export
def sample(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
    """Return positive and negative random-walk samples for ``batch``.

    Args:
        batch: Node indices to start walks from. A plain Python sequence
            (as handed over by the ``DataLoader`` collate path) is
            converted to a tensor first.

    Returns:
        A ``(pos_rw, neg_rw)`` tuple of positive and negative walks.
    """
    # DataLoader may pass a list of ints rather than a Tensor.
    if not isinstance(batch, Tensor):
        batch = torch.tensor(batch)
    pos_rw = self.pos_sample(batch)
    neg_rw = self.neg_sample(batch)
    return pos_rw, neg_rw

@torch.jit.export
def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
r"""Computes the loss given positive and negative random walks."""

Expand All @@ -142,7 +144,7 @@ def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
self.embedding_dim)

out = (h_start * h_rest).sum(dim=-1).view(-1)
pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()
pos_loss = -torch.log(torch.sigmoid(out) + self.EPS).mean()

# Negative loss.
start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()
Expand All @@ -153,7 +155,7 @@ def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
self.embedding_dim)

out = (h_start * h_rest).sum(dim=-1).view(-1)
neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()
neg_loss = -torch.log(1 - torch.sigmoid(out) + self.EPS).mean()

return pos_loss + neg_loss

Expand Down

0 comments on commit a0ffd6f

Please sign in to comment.