Skip to content

Commit

Permalink
Condition sampled_op execution path (#6805)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Feb 27, 2023
1 parent eac6cf1 commit 903780f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Correctly use edge weights in `GDC` example ([#6159](https://github.com/pyg-team/pytorch_geometric/pull/6159))
- Breaking Change: Moved PyTorch Lightning data modules to `torch_geometric.data.lightning` ([#6140](https://github.com/pyg-team/pytorch_geometric/pull/6140))
- Make `torch_sparse` an optional dependency ([#6132](https://github.com/pyg-team/pytorch_geometric/pull/6132), [#6134](https://github.com/pyg-team/pytorch_geometric/pull/6134), [#6138](https://github.com/pyg-team/pytorch_geometric/pull/6138), [#6139](https://github.com/pyg-team/pytorch_geometric/pull/6139))
- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155))
- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155), [#6805](https://github.com/pyg-team/pytorch_geometric/pull/6805))
- Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123))

### Removed
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
try:
import pyg_lib # noqa
WITH_PYG_LIB = True
WITH_SAMPLED_OP = WITH_PYG_LIB and hasattr(pyg_lib.ops, 'sampled_add')
WITH_INDEX_SORT = WITH_PYG_LIB and hasattr(pyg_lib.ops, 'index_sort')
except (ImportError, OSError) as e:
if isinstance(e, OSError):
warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
f"Disabling its usage. Stacktrace: {e}")
pyg_lib = object
WITH_PYG_LIB = False
WITH_SAMPLED_OP = False
WITH_INDEX_SORT = False

try:
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/utils/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def softmax(
N = maybe_num_nodes(index, num_nodes)
with torch.no_grad():
src_max = scatter(src, index, dim, dim_size=N, reduce='max')
if (torch_geometric.typing.WITH_PYG_LIB and src.dim() == 2
if (torch_geometric.typing.WITH_SAMPLED_OP and src.dim() == 2
and (dim == 0 or dim == -2)):
out = pyg_lib.ops.sampled_sub(src, src_max, right_index=index)
else:
Expand Down

0 comments on commit 903780f

Please sign in to comment.