Skip to content

Commit

Permalink
Define custom op for grid points generator of single level feature map (
Browse files Browse the repository at this point in the history
pytorch#4395)

Summary:
Pull Request resolved: pytorch#4395

In order to lower `grid_priors` function that generate grids to Vulkan, we plan to implement this function in one shader due to several reasons like compilation issue of meshgrid, and reduce data copy.

Define this function into one operator and will implement this op in the following diff.

The spec of this op is:
```
(int height, int width, int stride, float offset) -> Tensor
```

Example:
```
height = 2
width = 3
stride = 1
offset = 0
output.shape = [3x2, 2]
output = tensor([[0, 0],
        [1, 0],
        [2, 0],
        [0, 1],
        [1, 1],
        [2, 1]])
```

Reviewed By: jorgep31415

Differential Revision: D60141165

fbshipit-source-id: f56f04671eb5ca75c6a06c4b70b4067a0dc43e2a
  • Loading branch information
Yujie Hui authored and facebook-github-bot committed Jul 25, 2024
1 parent dbf87b0 commit 77c905d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
26 changes: 24 additions & 2 deletions backends/vulkan/passes/custom_ops_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import torch.library

namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")


def conv_with_clamp_impl(
input,
Expand Down Expand Up @@ -37,11 +40,30 @@ def conv_with_clamp_impl(
)


namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")
name = "conv_with_clamp"
lib.define(
f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor"
)
lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)


def grid_priors_impl(
height,
width,
stride,
offset,
):
shift_x = (torch.arange(0, width) + offset) * stride
shift_y = (torch.arange(0, height) + offset) * stride
shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x)
shift_xx = shift_xx.reshape(-1)
shift_yy = shift_yy.reshape(-1)
shifts = torch.stack((shift_yy, shift_xx), dim=-1)
return shifts


name = "grid_priors"
lib.define(f"{name}(int height, int width, int stride, float offset) -> Tensor")
lib.impl(name, grid_priors_impl)
grid_priors_op = getattr(getattr(torch.ops, namespace), name)
30 changes: 30 additions & 0 deletions backends/vulkan/passes/test_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,33 @@ def forward(self, x):
"custom op `conv_with_clamp` output shape matches expected",
)
self.assertTrue(torch.allclose(custom_out, expected_out))

def test_grid_priors(self):
class GridPriors(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, height, width, stride, offset):
return torch.ops.et_vk.grid_priors(height, width, stride, offset)

model = GridPriors()
sample_input = (2, 3, 4, 0.5)
custom_out = model(*sample_input)

def calculate_expected_output(height, width, stride, offset):
shift_x = (torch.arange(0, width) + offset) * stride
shift_y = (torch.arange(0, height) + offset) * stride
shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x)
shift_xx = shift_xx.reshape(-1)
shift_yy = shift_yy.reshape(-1)
shifts = torch.stack((shift_yy, shift_xx), dim=-1)
return shifts

expected_out = calculate_expected_output(*sample_input)

self.assertEqual(
custom_out.shape,
expected_out.shape,
"custom op `grid_priors` output shape matches expected",
)
self.assertTrue(torch.allclose(custom_out, expected_out))

0 comments on commit 77c905d

Please sign in to comment.