-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
DistMult
KGE model (#6958)
Achieves .40 Hits@10 on FB15K in example (.41 on validation), which is comparable to the [PapersWithCode](https://paperswithcode.com/paper/embedding-entities-and-relations-for-learning) metric value of .419 --------- Co-authored-by: David Kuo <davidekuo@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
- Loading branch information
1 parent
98abaae
commit d3043d8
Showing
7 changed files
with
121 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import torch | ||
|
||
from torch_geometric.nn import DistMult | ||
|
||
|
||
def test_distmult(): | ||
model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32) | ||
assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)' | ||
|
||
head_index = torch.tensor([0, 2, 4, 6, 8]) | ||
rel_type = torch.tensor([0, 1, 2, 3, 4]) | ||
tail_index = torch.tensor([1, 3, 5, 7, 9]) | ||
|
||
loader = model.loader(head_index, rel_type, tail_index, batch_size=5) | ||
for h, r, t in loader: | ||
out = model(h, r, t) | ||
assert out.size() == (5, ) | ||
|
||
loss = model.loss(h, r, t) | ||
assert loss >= 0. | ||
|
||
mean_rank, hits_at_10 = model.test(h, r, t, batch_size=5, log=False) | ||
assert mean_rank <= 10 | ||
assert hits_at_10 == 1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
from .base import KGEModel | ||
from .transe import TransE | ||
from .complex import ComplEx | ||
from .distmult import DistMult | ||
|
||
__all__ = classes = [ | ||
'KGEModel', | ||
'TransE', | ||
'ComplEx', | ||
'DistMult', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import Tensor | ||
|
||
from torch_geometric.nn.kge import KGEModel | ||
|
||
|
||
class DistMult(KGEModel): | ||
r"""The DistMult model from the `"Embedding Entities and Relations for | ||
Learning and Inference in Knowledge Bases" | ||
<https://arxiv.org/abs/1412.6575>`_ paper. | ||
:class:`DistMult` models relations as diagonal matrices, which simplifies | ||
the bi-linear interaction between the head and tail entities to the score | ||
function: | ||
.. math:: | ||
d(h, r, t) = < \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t > | ||
.. note:: | ||
For an example of using the :class:`DistMult` model, see | ||
`examples/kge_fb15k_237.py | ||
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ | ||
kge_fb15k_237.py>`_. | ||
Args: | ||
num_nodes (int): The number of nodes/entities in the graph. | ||
num_relations (int): The number of relations in the graph. | ||
hidden_channels (int): The hidden embedding size. | ||
margin (float, optional): The margin of the ranking loss. | ||
(default: :obj:`1.0`) | ||
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to | ||
the embedding matrices will be sparse. (default: :obj:`False`) | ||
""" | ||
def __init__( | ||
self, | ||
num_nodes: int, | ||
num_relations: int, | ||
hidden_channels: int, | ||
margin: float = 1.0, | ||
sparse: bool = False, | ||
): | ||
super().__init__(num_nodes, num_relations, hidden_channels, sparse) | ||
|
||
self.margin = margin | ||
|
||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
torch.nn.init.xavier_uniform_(self.node_emb.weight) | ||
torch.nn.init.xavier_uniform_(self.rel_emb.weight) | ||
|
||
def forward( | ||
self, | ||
head_index: Tensor, | ||
rel_type: Tensor, | ||
tail_index: Tensor, | ||
) -> Tensor: | ||
|
||
head = self.node_emb(head_index) | ||
rel = self.rel_emb(rel_type) | ||
tail = self.node_emb(tail_index) | ||
|
||
return (head * rel * tail).sum(dim=-1) | ||
|
||
def loss( | ||
self, | ||
head_index: Tensor, | ||
rel_type: Tensor, | ||
tail_index: Tensor, | ||
) -> Tensor: | ||
|
||
pos_score = self(head_index, rel_type, tail_index) | ||
neg_score = self(*self.random_sample(head_index, rel_type, tail_index)) | ||
|
||
return F.margin_ranking_loss( | ||
pos_score, | ||
neg_score, | ||
target=torch.ones_like(pos_score), | ||
margin=self.margin, | ||
) |