Skip to content

Commit

Permalink
Merge pull request #573 from DesmonDay/main
Browse files Browse the repository at this point in the history
Update code
  • Loading branch information
Yelrose authored Sep 21, 2023
2 parents bf120a2 + 324d857 commit 1c49fd3
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
3 changes: 2 additions & 1 deletion pgl/bigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,8 @@ def _join_graph_index(graph_list, mode="src_node"):
% mode)

if is_tensor:
counts = paddle.concat(counts)
counts = [c.item() for c in counts]
counts = paddle.to_tensor(counts, dtype="int64")
return op.get_index_from_counts(counts)

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion pgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,8 @@ def _join_graph_index(graph_list, mode="node"):
mode)

if is_tensor:
counts = paddle.concat(counts)
counts = [c.item() for c in counts]
counts = paddle.to_tensor(counts, dtype="int64")
return op.get_index_from_counts(counts)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
numpy >= 1.16.4
cython >= 0.25.2
numpy==1.26.0
cython==3.0.2
2 changes: 1 addition & 1 deletion tests/test_static_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_static_graph(self):
model2.set_state_dict(state_dict)

feed_dict = {
"num_nodes": num_nodes,
"num_nodes": np.array([num_nodes]).astype("int32"),
"edges": np.array(
edges, dtype="int32"),
"feature": nfeat.astype("float32"),
Expand Down

0 comments on commit 1c49fd3

Please sign in to comment.