Consider adding a pure PyTorch function for `segment_mm` and `gather_mm` #56
Comments
For the record, the initial workaround

```python
import torch

def my_gather_mm(a, b, idx_b):
    # mimic https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
    R, D1, D2 = b.shape
    N = idx_b.shape[0]
    # Sanity check sizes
    assert a.shape[0] == N and a.shape[1] == D1
    torchdevice = a.device
    src_idx = torch.arange(N, device=torchdevice)
    # Ideally the conversions below to nested tensor would be handled without for loops and without copy
    nested_a = torch.nested.as_nested_tensor(
        [torch.index_select(a, dim=0, index=torch.nonzero(idx_b == i).squeeze()) for i in range(R)])
    src_idx_reshuffled = torch.cat(
        [torch.index_select(src_idx, dim=0, index=torch.nonzero(idx_b == i).squeeze()) for i in range(R)])
    nested_b = torch.nested.as_nested_tensor(
        [b[i, :, :].squeeze() for i in range(R)])
    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)
    # Convert back to tensors; again, ideally this would be handled natively with no copy
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    ab = torch.empty((N, D2), device=torchdevice)
    ab[src_idx_reshuffled] = ab_segmented
    return ab
```

can be simplified a bit:

```python
def my_gather_mm(a, b, idx_b):
    # mimic https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
    R, D1, D2 = b.shape
    N = idx_b.shape[0]
    # Sanity check sizes
    assert a.shape[0] == N and a.shape[1] == D1
    torchdevice = a.device
    src_idx = torch.arange(N, device=torchdevice)
    # Ideally the conversions below to nested tensor would be handled without for loops and without copy
    nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
    src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
    nested_b = torch.nested.as_nested_tensor([b[i, :, :].squeeze() for i in range(R)])
    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)
    # Convert back to tensors; again, ideally this would be handled natively with no copy
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    ab = torch.empty((N, D2), device=torchdevice)
    ab[src_idx_reshuffled] = ab_segmented
    return ab
```
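A minimal usage sketch (not part of the original comment; the shapes and names below are arbitrary) showing how the workaround is called:

```python
# Hypothetical usage of the workaround above
R, D1, D2, N = 3, 4, 5, 10
a = torch.randn(N, D1)            # N rows to be multiplied
b = torch.randn(R, D1, D2)        # R weight matrices
idx_b = torch.arange(N) % R       # which matrix in b each row of a uses
ab = my_gather_mm(a, b, idx_b)    # row i equals a[i] @ b[idx_b[i]], shape (N, D2)
```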
tvercaut added a commit that referenced this issue on Sep 26, 2024

tvercaut added a commit that referenced this issue on Sep 27, 2024
* Added simple workarounds for gather_mm and segment_mm. See #56
* bumping python and pytorch version in CI
* enabling black on notebooks in CI
* updating github actions to avoid deprecation warning
`segment_mm` and `gather_mm` are borderline in scope for this repo, but they would still be useful additions. DGL provides some nice functions, but installing it is non-trivial.

- https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
- https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
- https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul
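For reference, a pure-PyTorch `segment_mm` in the same nested-tensor spirit as the `gather_mm` workaround quoted above could look roughly like the sketch below. This is a hypothetical, untested starting point (not something already in the repo); the name `my_segment_mm` and the `seglen_a` argument follow DGL's documented interface, where `seglen_a[i]` gives the number of rows of `a` multiplied by `b[i]`.

```python
import torch

def my_segment_mm(a, b, seglen_a):
    # Rough sketch mimicking https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
    # a: (N, D1), b: (R, D1, D2), seglen_a: (R,) with seglen_a.sum() == N
    R, D1, D2 = b.shape
    assert a.shape[1] == D1 and seglen_a.sum().item() == a.shape[0]
    # Split a into R contiguous segments and pair segment i with b[i]
    # Note: assumes non-empty segments, as empty nested-tensor components may not be supported
    segments = torch.split(a, seglen_a.tolist(), dim=0)
    nested_a = torch.nested.as_nested_tensor(list(segments))
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])
    nested_ab = torch.matmul(nested_a, nested_b)
    # Concatenate the per-segment results back into a single (N, D2) tensor
    return torch.cat(nested_ab.unbind(), dim=0)
```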
A workaround for `gather_mm` is found here and would be a good starting point for inclusion in torchsparsegradutils (just needs a unit test): pytorch/pytorch#136747
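As an illustration of what such a unit test could look like, here is a hypothetical sketch (not existing repo code, assuming a `my_gather_mm` implementation like the one quoted above) that compares the output against a naive per-row loop:

```python
import torch

def test_my_gather_mm_matches_per_row_loop():
    # Hypothetical test sketch; my_gather_mm is assumed to be the workaround quoted above
    torch.manual_seed(0)
    R, D1, D2, N = 4, 3, 6, 12
    a = torch.randn(N, D1)
    b = torch.randn(R, D1, D2)
    idx_b = torch.arange(N) % R  # every weight matrix in b is used at least once

    ab = my_gather_mm(a, b, idx_b)

    # Reference: row i of the output is a[i] @ b[idx_b[i]]
    expected = torch.stack([a[i] @ b[idx_b[i]] for i in range(N)])
    assert ab.shape == (N, D2)
    assert torch.allclose(ab, expected, atol=1e-6)
```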