Skip to content

Commit

Permalink
feat: add DistMult KGE model (#6958)
Browse files Browse the repository at this point in the history
Achieves .40 Hits@10 on FB15K in example (.41 on validation), which is
comparable to the
[PapersWithCode](https://paperswithcode.com/paper/embedding-entities-and-relations-for-learning)
metric value of .419

---------

Co-authored-by: David Kuo <davidekuo@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
4 people authored Mar 19, 2023
1 parent 98abaae commit d3043d8
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `DistMult` KGE model ([#6958](https://github.com/pyg-team/pytorch_geometric/pull/6958))
- Added `HeteroData.set_value_dict` functionality ([#6961](https://github.com/pyg-team/pytorch_geometric/pull/6961))
- Added PyTorch >= 2.0 support ([#6934](https://github.com/pyg-team/pytorch_geometric/pull/6934))
- Added PyTorch Lightning >= 2.0 support ([#6929](https://github.com/pyg-team/pytorch_geometric/pull/6929))
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ Unlike simple stacking of GNN layers, these models could involve pre-processing,
New Benchmarks and Strong Simple Methods](https://arxiv.org/abs/2110.14446) (NeurIPS 2021) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/linkx.py)]
* **[RevGNN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GroupAddRev.html)** from Li *et al.*: [Training Graph Neural with 1000 Layers](https://arxiv.org/abs/2106.07476) (ICML 2021) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rev_gnn.py)]
* **[TransE](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.TransE.html)** from Bordes *et al.*: [Translating Embeddings for Modeling Multi-Relational Data](https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf) (NIPS 2013) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)]
* **[ComplEx](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.ComplEx.html)** from Trouillon *et al.*: [Complex Embeddings for Simple Link Prediction](https://arxiv.org/pdf/1606.06357.pdf) (ICML 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)]
* **[ComplEx](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.ComplEx.html)** from Trouillon *et al.*: [Complex Embeddings for Simple Link Prediction](https://arxiv.org/abs/1606.06357) (ICML 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)]
* **[DistMult](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.DistMult.html)** from Yang *et al.*: [Embedding Entities and Relations for Learning and Inference in Knowledge Bases](https://arxiv.org/abs/1412.6575) (ICLR 2015) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)]
</details>

**GNN operators and utilities:**
Expand Down
13 changes: 9 additions & 4 deletions examples/kge_fb15k_237.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, TransE
from torch_geometric.nn import ComplEx, DistMult, TransE

model_map = {'transe': TransE, 'complex': ComplEx}
model_map = {
'transe': TransE,
'complex': ComplEx,
'distmult': DistMult,
}

parser = argparse.ArgumentParser()
parser.add_argument('--model', choices=model_map.keys(), type=str.lower,
Expand Down Expand Up @@ -36,8 +40,9 @@
)

optimizer_map = {
'transe': torch.optim.Adam(model.parameters(), lr=0.01),
'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6)
'transe': optim.Adam(model.parameters(), lr=0.01),
'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6),
'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),
}
optimizer = optimizer_map[args.model]

Expand Down
24 changes: 24 additions & 0 deletions test/nn/kge/test_distmult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch

from torch_geometric.nn import DistMult


def test_distmult():
model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32)
assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
rel_type = torch.tensor([0, 1, 2, 3, 4])
tail_index = torch.tensor([1, 3, 5, 7, 9])

loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
for h, r, t in loader:
out = model(h, r, t)
assert out.size() == (5, )

loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, hits_at_10 = model.test(h, r, t, batch_size=5, log=False)
assert mean_rank <= 10
assert hits_at_10 == 1.0
2 changes: 2 additions & 0 deletions torch_geometric/nn/kge/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .base import KGEModel
from .transe import TransE
from .complex import ComplEx
from .distmult import DistMult

__all__ = classes = [
'KGEModel',
'TransE',
'ComplEx',
'DistMult',
]
2 changes: 1 addition & 1 deletion torch_geometric/nn/kge/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class ComplEx(KGEModel):
r"""The ComplEx model from the `"Complex Embeddings for Simple Link
Prediction" <https://arxiv.org/pdf/1606.06357.pdf>`_ paper.
Prediction" <https://arxiv.org/abs/1606.06357>`_ paper.
:class:`ComplEx` models relations as complex-valued bilinear mappings
between head and tail entities using the Hermetian dot product.
Expand Down
82 changes: 82 additions & 0 deletions torch_geometric/nn/kge/distmult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.kge import KGEModel


class DistMult(KGEModel):
r"""The DistMult model from the `"Embedding Entities and Relations for
Learning and Inference in Knowledge Bases"
<https://arxiv.org/abs/1412.6575>`_ paper.
:class:`DistMult` models relations as diagonal matrices, which simplifies
the bi-linear interaction between the head and tail entities to the score
function:
.. math::
d(h, r, t) = < \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t >
.. note::
For an example of using the :class:`DistMult` model, see
`examples/kge_fb15k_237.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
kge_fb15k_237.py>`_.
Args:
num_nodes (int): The number of nodes/entities in the graph.
num_relations (int): The number of relations in the graph.
hidden_channels (int): The hidden embedding size.
margin (float, optional): The margin of the ranking loss.
(default: :obj:`1.0`)
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
the embedding matrices will be sparse. (default: :obj:`False`)
"""
def __init__(
self,
num_nodes: int,
num_relations: int,
hidden_channels: int,
margin: float = 1.0,
sparse: bool = False,
):
super().__init__(num_nodes, num_relations, hidden_channels, sparse)

self.margin = margin

self.reset_parameters()

def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.node_emb.weight)
torch.nn.init.xavier_uniform_(self.rel_emb.weight)

def forward(
self,
head_index: Tensor,
rel_type: Tensor,
tail_index: Tensor,
) -> Tensor:

head = self.node_emb(head_index)
rel = self.rel_emb(rel_type)
tail = self.node_emb(tail_index)

return (head * rel * tail).sum(dim=-1)

def loss(
self,
head_index: Tensor,
rel_type: Tensor,
tail_index: Tensor,
) -> Tensor:

pos_score = self(head_index, rel_type, tail_index)
neg_score = self(*self.random_sample(head_index, rel_type, tail_index))

return F.margin_ranking_loss(
pos_score,
neg_score,
target=torch.ones_like(pos_score),
margin=self.margin,
)

0 comments on commit d3043d8

Please sign in to comment.