Consider adding a pure PyTorch function for segment_mm and gather_mm #56

Closed · tvercaut opened this issue Sep 26, 2024 · 1 comment


segment_mm and gather_mm are borderline in scope for this repo, but they would still be useful additions. DGL provides some nice functions, but installing it is non-trivial:
https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul
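
For reference, and as I read the DGL documentation above, both operations boil down to per-row / per-segment matrix products. A naive pure-PyTorch reference (an illustrative sketch only, with hypothetical names naive_gather_mm / naive_segment_mm, not an efficient implementation) would be:

import torch


def naive_gather_mm(a, b, idx_b):
    # out[i] = a[i] @ b[idx_b[i]] with a: (N, D1), b: (R, D1, D2), idx_b: (N,)
    return torch.stack([a[i] @ b[idx_b[i]] for i in range(a.shape[0])])


def naive_segment_mm(a, b, seglen_a):
    # Split the rows of a into R segments of lengths seglen_a (a 1-D integer
    # tensor summing to N) and multiply segment k by b[k];
    # a: (N, D1), b: (R, D1, D2), output: (N, D2)
    segments = torch.split(a, seglen_a.tolist())
    return torch.cat([seg @ b[k] for k, seg in enumerate(segments)], dim=0)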

A workaround for gather_mm can be found in the link below and would be a good starting point for inclusion in torchsparsegradutils (it just needs a unit test):
pytorch/pytorch#136747

tvercaut (Member, Author) commented Sep 26, 2024

For the record, the initial workaround

import torch


def my_gather_mm(a, b, idx_b):
    # mimic https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
    R, D1, D2 = b.shape
    N = idx_b.shape[0]

    # Sanity check sizes
    assert a.shape[0] == N and a.shape[1] == D1

    torchdevice = a.device
    src_idx = torch.arange(N, device=torchdevice)

    # Ideally the conversions below to nested tensors would be handled without for loops and without copy
    # (squeeze(-1) keeps the index 1-D even when a single row matches a given value of idx_b)
    nested_a = torch.nested.as_nested_tensor(
        [torch.index_select(a, dim=0, index=torch.nonzero(idx_b == i).squeeze(-1)) for i in range(R)])
    src_idx_reshuffled = torch.cat(
        [torch.index_select(src_idx, dim=0, index=torch.nonzero(idx_b == i).squeeze(-1)) for i in range(R)])
    nested_b = torch.nested.as_nested_tensor(
        [b[i] for i in range(R)])  # each b[i] is already a (D1, D2) matrix

    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)

    # Convert back to a regular tensor; again, ideally this would be handled natively with no copy
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    ab = torch.empty((N, D2), device=torchdevice, dtype=a.dtype)
    ab[src_idx_reshuffled] = ab_segmented
    return ab

can be simplified a bit

def my_gather_mm(a, b, idx_b):
    # mimic https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
    R, D1, D2 = b.shape
    N = idx_b.shape[0]

    # Sanity check sizes
    assert a.shape[0] == N and a.shape[1] == D1

    torchdevice = a.device
    src_idx = torch.arange(N, device=torchdevice)

    # Ideally the conversions below to nested tensors would be handled without for loops and without copy
    nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
    src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])

    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)

    # Convert back to a regular tensor; again, ideally this would be handled natively with no copy
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    ab = torch.empty((N, D2), device=torchdevice, dtype=a.dtype)
    ab[src_idx_reshuffled] = ab_segmented
    return ab
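
Since the main missing piece for inclusion is a unit test, here is a minimal sketch of one (assuming a pytest-style test with my_gather_mm importable from the surrounding module; the sizes and the brute-force reference loop are arbitrary illustrative choices):

import torch


def test_my_gather_mm():
    torch.manual_seed(0)
    N, R, D1, D2 = 10, 3, 4, 5
    a = torch.randn(N, D1)
    b = torch.randn(R, D1, D2)
    # idx_b = arange(N) % R guarantees every matrix in b is used at least once,
    # which avoids empty segments in the nested-tensor construction
    idx_b = torch.arange(N) % R

    # Brute-force reference: out[i] = a[i] @ b[idx_b[i]]
    expected = torch.stack([a[i] @ b[idx_b[i]] for i in range(N)])

    result = my_gather_mm(a, b, idx_b)
    assert result.shape == (N, D2)
    assert torch.allclose(result, expected, atol=1e-6)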

tvercaut added a commit that referenced this issue Sep 27, 2024
* Added simple workarounds for gather_mm and segment_mm. See #56

* bumping Python and PyTorch versions in CI

* enabling black on notebooks in CI

* updating GitHub Actions to avoid deprecation warnings
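
Regarding the segment_mm half of the request, a pure-PyTorch sketch in the same nested-tensor style as the gather_mm workaround above (an illustration of the idea only, under the hypothetical name my_segment_mm, not the code from the referenced commit; it assumes seglen_a is a 1-D integer tensor with positive entries summing to a.shape[0]) could look like:

import torch


def my_segment_mm(a, b, seglen_a):
    # mimic https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
    # a: (N, D1) split into R row segments of lengths seglen_a; b: (R, D1, D2)
    R, D1, D2 = b.shape
    assert seglen_a.shape[0] == R and int(seglen_a.sum()) == a.shape[0]
    assert a.shape[1] == D1

    # Per-segment matmul via nested tensors (segments stay in their original order)
    nested_a = torch.nested.as_nested_tensor(list(torch.split(a, seglen_a.tolist())))
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])
    nested_ab = torch.matmul(nested_a, nested_b)

    # Since the segments are contiguous and in order, concatenating restores the (N, D2) layout
    return torch.cat(nested_ab.unbind(), dim=0)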