Skip to content

Commit

Permalink
Merge pull request #5 from pinellolab/add_edge_weight
Browse files Browse the repository at this point in the history
Add support for edge weights
  • Loading branch information
huidongchen authored Oct 12, 2022
2 parents 0e6f837 + adfcd97 commit e0d3bbf
Show file tree
Hide file tree
Showing 10 changed files with 319 additions and 78 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# Installation

To install the customized PyTorch-BigGraph (PBG) package [`simba_pbg`](https://anaconda.org/bioconda/simba_pbg) for [`simba`](https://github.com/pinellolab/simba),

```bash
conda install -c bioconda simba_pbg
```


# ![PyTorch-BigGraph](docs/source/_static/logo_color.svg)

[![CircleCI Status](https://circleci.com/gh/facebookresearch/PyTorch-BigGraph.svg?style=svg)](https://circleci.com/gh/facebookresearch/PyTorch-BigGraph) [![Documentation Status](https://readthedocs.org/projects/torchbiggraph/badge/?version=latest)](https://torchbiggraph.readthedocs.io/en/latest/?badge=latest)
Expand Down
96 changes: 81 additions & 15 deletions test/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,63 @@ def test_forward(self):
requires_grad=True,
)
loss_fn = LogisticLossFunction()
loss = loss_fn(pos_scores, neg_scores)
loss = loss_fn(pos_scores, neg_scores, None)
self.assertTensorEqual(loss, torch.tensor(4.2589))
loss.backward()
self.assertTrue((pos_scores.grad != 0).any())
self.assertTrue((neg_scores.grad != 0).any())

def test_forward_weight(self):
    """Per-edge weights scale the logistic loss; a zero weight drops an edge."""
    loss_fn = LogisticLossFunction()
    pos_scores = torch.tensor([0.8181, 0.5700, 0.3506], requires_grad=True)
    neg_rows = [
        [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
        [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
        [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
    ]
    neg_scores = torch.tensor(neg_rows, requires_grad=True)
    # A uniform weight multiplies the unweighted loss (4.2589) by that factor.
    uniform = torch.full((3,), 1.23)
    loss = loss_fn(pos_scores, neg_scores, uniform)
    self.assertTensorEqual(loss, torch.tensor(4.2589 * 1.23))
    loss.backward()
    self.assertTrue((pos_scores.grad != 0).any())
    self.assertTrue((neg_scores.grad != 0).any())

    # Mixed per-edge weights; the zero entry masks the third edge entirely.
    per_edge = torch.tensor([0.2, 0.4, 0.0])
    self.assertTensorEqual(
        loss_fn(pos_scores, neg_scores, per_edge), torch.tensor(0.8302)
    )

def test_forward_good(self):
    """Saturated correct scores (pos >> neg) drive the logistic loss to zero."""
    pos_scores = torch.full((3,), +1e9, requires_grad=True)
    neg_scores = torch.full((3, 5), -1e9, requires_grad=True)
    loss_fn = LogisticLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

def test_forward_bad(self):
    """Saturated wrong scores (pos << neg) yield a loss of ~2e9 per edge."""
    pos_scores = torch.full((3,), -1e9, requires_grad=True)
    neg_scores = torch.full((3, 5), +1e9, requires_grad=True)
    loss_fn = LogisticLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.tensor(6e9))
    loss.backward()

def test_no_neg(self):
    """With no negatives, only the positive term of the logistic loss remains."""
    pos_scores = torch.zeros((3,), requires_grad=True)
    neg_scores = torch.empty((3, 0), requires_grad=True)
    loss_fn = LogisticLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.tensor(2.0794))
    loss.backward()

def test_no_pos(self):
    """An empty batch produces a zero loss and a valid backward pass."""
    pos_scores = torch.empty((0,), requires_grad=True)
    neg_scores = torch.empty((0, 0), requires_grad=True)
    loss_fn = LogisticLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

Expand All @@ -98,41 +120,63 @@ def test_forward(self):
requires_grad=True,
)
loss_fn = RankingLossFunction(margin=1.0)
loss = loss_fn(pos_scores, neg_scores)
loss = loss_fn(pos_scores, neg_scores, None)
self.assertTensorEqual(loss, torch.tensor(13.4475))
loss.backward()
self.assertTrue((pos_scores.grad != 0).any())
self.assertTrue((neg_scores.grad != 0).any())

def test_forward_weight(self):
    """Per-edge weights scale the ranking loss; a zero weight drops an edge."""
    loss_fn = RankingLossFunction(margin=1.0)
    pos_scores = torch.tensor([0.8181, 0.5700, 0.3506], requires_grad=True)
    neg_rows = [
        [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
        [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
        [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
    ]
    neg_scores = torch.tensor(neg_rows, requires_grad=True)
    # A uniform weight multiplies the unweighted loss (13.4475) by that factor.
    uniform = torch.full((3,), 1.23)
    loss = loss_fn(pos_scores, neg_scores, uniform)
    self.assertTensorEqual(loss, torch.tensor(13.4475 * 1.23))
    loss.backward()
    self.assertTrue((pos_scores.grad != 0).any())
    self.assertTrue((neg_scores.grad != 0).any())

    # Mixed per-edge weights; the zero entry masks the third edge entirely.
    per_edge = torch.tensor([0.2, 0.4, 0.0])
    self.assertTensorEqual(
        loss_fn(pos_scores, neg_scores, per_edge), torch.tensor(2.4658)
    )

def test_forward_good(self):
    """Positives beating negatives by exactly the margin give zero ranking loss."""
    pos_scores = torch.full((3,), 2.0, requires_grad=True)
    neg_scores = torch.full((3, 5), 1.0, requires_grad=True)
    loss_fn = RankingLossFunction(margin=1.0)
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

def test_forward_bad(self):
    """Each margin violation of 2.0 over 3x5 pairs sums to a loss of 30."""
    pos_scores = torch.full((3,), -1.0, requires_grad=True)
    neg_scores = torch.zeros((3, 5), requires_grad=True)
    loss_fn = RankingLossFunction(margin=1.0)
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.tensor(30.0))
    loss.backward()

def test_no_neg(self):
    """With no negatives there are no rank violations, so the loss is zero."""
    pos_scores = torch.zeros((3,), requires_grad=True)
    neg_scores = torch.empty((3, 0), requires_grad=True)
    loss_fn = RankingLossFunction(margin=1.0)
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

def test_no_pos(self):
    """An empty batch produces a zero loss and a valid backward pass."""
    pos_scores = torch.empty((0,), requires_grad=True)
    neg_scores = torch.empty((0, 3), requires_grad=True)
    loss_fn = RankingLossFunction(margin=1.0)
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

Expand All @@ -149,41 +193,63 @@ def test_forward(self):
requires_grad=True,
)
loss_fn = SoftmaxLossFunction()
loss = loss_fn(pos_scores, neg_scores)
loss = loss_fn(pos_scores, neg_scores, None)
self.assertTensorEqual(loss, torch.tensor(5.2513))
loss.backward()
self.assertTrue((pos_scores.grad != 0).any())
self.assertTrue((neg_scores.grad != 0).any())

def test_forward_weight(self):
    """Per-edge weights scale the softmax loss; a zero weight drops an edge."""
    loss_fn = SoftmaxLossFunction()
    pos_scores = torch.tensor([0.8181, 0.5700, 0.3506], requires_grad=True)
    neg_rows = [
        [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
        [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
        [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
    ]
    neg_scores = torch.tensor(neg_rows, requires_grad=True)
    # A uniform weight multiplies the unweighted loss (5.2513) by that factor.
    uniform = torch.full((3,), 1.23)
    loss = loss_fn(pos_scores, neg_scores, uniform)
    self.assertTensorEqual(loss, torch.tensor(5.2513 * 1.23))
    loss.backward()
    self.assertTrue((pos_scores.grad != 0).any())
    self.assertTrue((neg_scores.grad != 0).any())

    # Mixed per-edge weights; the zero entry masks the third edge entirely.
    per_edge = torch.tensor([0.2, 0.4, 0.0])
    self.assertTensorEqual(
        loss_fn(pos_scores, neg_scores, per_edge), torch.tensor(0.9978)
    )

def test_forward_good(self):
    """Saturated correct scores (pos >> neg) drive the softmax loss to zero."""
    pos_scores = torch.full((3,), +1e9, requires_grad=True)
    neg_scores = torch.full((3, 5), -1e9, requires_grad=True)
    loss_fn = SoftmaxLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

def test_forward_bad(self):
    """Saturated wrong scores (pos << neg) yield a loss of ~2e9 per edge."""
    pos_scores = torch.full((3,), -1e9, requires_grad=True)
    neg_scores = torch.full((3, 5), +1e9, requires_grad=True)
    loss_fn = SoftmaxLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.tensor(6e9))
    loss.backward()

def test_no_neg(self):
    """With no negatives the softmax over each row is trivial, so loss is zero."""
    pos_scores = torch.zeros((3,), requires_grad=True)
    neg_scores = torch.empty((3, 0), requires_grad=True)
    loss_fn = SoftmaxLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

def test_no_pos(self):
    """An empty batch produces a zero loss and a valid backward pass."""
    pos_scores = torch.empty((0,), requires_grad=True)
    neg_scores = torch.empty((0, 3), requires_grad=True)
    loss_fn = SoftmaxLossFunction()
    # weight=None requests the unweighted loss.
    loss = loss_fn(pos_scores, neg_scores, None)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()

Expand Down
11 changes: 8 additions & 3 deletions torchbiggraph/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def group_by_relation_type(edges: EdgeList) -> List[EdgeList]:
rel_type = sorted_rel[start]
edges_for_rel_type = edges[order[start:end]]
result.append(
EdgeList(edges_for_rel_type.lhs, edges_for_rel_type.rhs, rel_type)
EdgeList(
edges_for_rel_type.lhs,
edges_for_rel_type.rhs,
rel_type,
edges_for_rel_type.weight,
)
)
return result

Expand Down Expand Up @@ -133,8 +138,8 @@ def __init__(

def calc_loss(self, scores: Scores, batch_edges: EdgeList):

lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg)
rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg)
lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg, batch_edges.weight)
rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg, batch_edges.weight)
relation = (
batch_edges.get_relation_type_as_scalar()
if batch_edges.has_scalar_relation_type()
Expand Down
5 changes: 4 additions & 1 deletion torchbiggraph/converters/import_from_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def main():
help="Column index for target entity",
)
parser.add_argument("--rel-col", type=str, help="Column index for relation entity")
parser.add_argument(
"--weight-col", type=int, help="(Optional) Column index for edge weight"
)
parser.add_argument(
"--relation-type-min-count",
type=int,
Expand All @@ -66,7 +69,7 @@ def main():
entity_path,
edge_paths,
opt.edge_paths,
ParquetEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col),
ParquetEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col, opt.weight_col),
opt.entity_min_count,
opt.relation_type_min_count,
dynamic_relations,
Expand Down
5 changes: 4 additions & 1 deletion torchbiggraph/converters/import_from_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def main():
help="Column index for target entity",
)
parser.add_argument("--rel-col", type=int, help="Column index for relation entity")
parser.add_argument(
"--weight-col", type=int, help="(Optional) Column index for edge weight"
)
parser.add_argument(
"--relation-type-min-count",
type=int,
Expand All @@ -66,7 +69,7 @@ def main():
entity_path,
edge_paths,
opt.edge_paths,
TSVEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col),
TSVEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col, opt.weight_col),
opt.entity_min_count,
opt.relation_type_min_count,
dynamic_relations,
Expand Down
Loading

0 comments on commit e0d3bbf

Please sign in to comment.