diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index e96c1577ba..44d2e6cbbd 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -66,6 +66,7 @@
 from .sync_bn import SyncBatchNorm
 from .three_interpolate import three_interpolate
 from .three_nn import three_nn
+from .three_nn_vector_pool import three_nn_vector_pool_by_two_step
 from .tin_shift import TINShift, tin_shift
 from .upfirdn2d import upfirdn2d
 from .voxelize import Voxelization, voxelization
@@ -102,5 +103,5 @@
     'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
     'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
     'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
-    'PrRoIPool', 'prroi_pool'
+    'PrRoIPool', 'prroi_pool', 'three_nn_vector_pool_by_two_step'
 ]
diff --git a/mmcv/ops/csrc/common/cuda/three_nn_vector_pool_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/three_nn_vector_pool_cuda_kernel.cuh
new file mode 100644
index 0000000000..e60f43ca06
--- /dev/null
+++ b/mmcv/ops/csrc/common/cuda/three_nn_vector_pool_cuda_kernel.cuh
@@ -0,0 +1,184 @@
+#ifndef THREE_NN_VECTOR_POOL_CUDA_KERNEL_CUH
+#define THREE_NN_VECTOR_POOL_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+// Collects, for every query point, the global indices of the support points
+// within max_neighbour_distance and appends them to stack_neighbor_idxs.
+__global__ void query_stacked_local_neighbor_idxs_cuda_kernel(
+    const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz,
+    const int *new_xyz_batch_cnt, int *stack_neighbor_idxs, int *start_len,
+    int *cumsum, int avg_length_of_neighbor_idxs, float max_neighbour_distance,
+    int batch_size, int M, int nsample, int neighbor_type) {
+  // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
+  // xyz_batch_cnt: (batch_size), [N1, N2, ...]
+  // new_xyz: (M1 + M2 ..., 3) centers of the ball query
+  // new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
+  // stack_neighbor_idxs: (max_length_of_neighbor_idxs)
+  // start_len: (M1 + M2, 2) [start_offset, neighbor_length]
+  // cumsum: (1), max offset of current data in stack_neighbor_idxs
+  // max_neighbour_distance: float
+  // nsample: find all (-1), find limited number(>0)
+  // neighbor_type: 1: ball, others: cube
+
+  // Per-thread scratch capacity: once it is full the scan stops, so at most
+  // this many neighbors are reported per query point.
+  constexpr int kMaxLocalNeighbors = 1000;
+  // Number of slots the caller sized stack_neighbor_idxs for.
+  const int max_thresh = avg_length_of_neighbor_idxs * M;
+  const float radius2 = max_neighbour_distance * max_neighbour_distance;
+
+  CUDA_1D_KERNEL_LOOP(pt_idx, M) {
+    // Find the batch sample this query point belongs to.
+    int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0];
+    for (int k = 1; k < batch_size; k++) {
+      if (pt_idx < pt_cnt) break;
+      pt_cnt += new_xyz_batch_cnt[k];
+      bs_idx = k;
+    }
+
+    // Index of that sample's first support point in the stacked layout.
+    int xyz_batch_start_idx = 0;
+    for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];
+
+    const float *cur_support_xyz = support_xyz + xyz_batch_start_idx * 3;
+    const float *cur_new_xyz = new_xyz + pt_idx * 3;
+    int *cur_start_len = start_len + pt_idx * 2;
+
+    const float new_x = cur_new_xyz[0];
+    const float new_y = cur_new_xyz[1];
+    const float new_z = cur_new_xyz[2];
+    const int n = xyz_batch_cnt[bs_idx];
+
+    int temp_idxs[kMaxLocalNeighbors];
+    int sample_cnt = 0;
+    for (int k = 0; k < n; ++k) {
+      const float local_x = cur_support_xyz[k * 3 + 0] - new_x;
+      const float local_y = cur_support_xyz[k * 3 + 1] - new_y;
+      const float local_z = cur_support_xyz[k * 3 + 2] - new_z;
+
+      if (neighbor_type == 1) {
+        // ball
+        if (local_x * local_x + local_y * local_y + local_z * local_z >
+            radius2) {
+          continue;
+        }
+      } else {
+        // cube
+        if (fabsf(local_x) > max_neighbour_distance ||
+            fabsf(local_y) > max_neighbour_distance ||
+            fabsf(local_z) > max_neighbour_distance) {
+          continue;
+        }
+      }
+      if (sample_cnt >= kMaxLocalNeighbors) break;
+      temp_idxs[sample_cnt++] = k;
+      if (nsample > 0 && sample_cnt >= nsample) break;
+    }
+
+    // Reserve a slot range in the shared buffer. The untruncated count is
+    // recorded even on overflow so cumsum reflects the space really needed.
+    const int start = atomicAdd(cumsum, sample_cnt);
+    cur_start_len[0] = start;
+    cur_start_len[1] = sample_cnt;
+
+    if (start >= max_thresh) continue;
+
+    // Clamp the copy so it never writes past the end of the buffer.
+    if (start + sample_cnt >= max_thresh) sample_cnt = max_thresh - start;
+
+    int *cur_stack_neighbor_idxs = stack_neighbor_idxs + start;
+    for (int k = 0; k < sample_cnt; k++) {
+      cur_stack_neighbor_idxs[k] = temp_idxs[k] + xyz_batch_start_idx;
+    }
+  }
+}
+
+// Picks, for every (query point, grid) pair, the three stacked neighbors
+// closest to the grid center.
+__global__ void query_three_nn_by_stacked_local_idxs_cuda_kernel(
+    const float *support_xyz, const float *new_xyz,
+    const float *new_xyz_grid_centers, int *new_xyz_grid_idxs,
+    float *new_xyz_grid_dist2, const int *stack_neighbor_idxs,
+    const int *start_len, int M, int num_total_grids) {
+  // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
+  // new_xyz: (M1 + M2 ..., 3) centers of the ball query
+  // new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of
+  //     each grid
+  // new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
+  // new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of
+  //     three-nn
+  // stack_neighbor_idxs: (max_length_of_neighbor_idxs)
+  // start_len: (M1 + M2, 2) [start_offset, neighbor_length]
+  const int grid_idx = blockIdx.y;
+  if (grid_idx >= num_total_grids) return;
+  CUDA_1D_KERNEL_LOOP(pt_idx, M) {
+    // Flat offset of this (point, grid) triple in the three grid tensors.
+    const int out_offset = pt_idx * num_total_grids * 3 + grid_idx * 3;
+    const float *cur_new_xyz_grid_centers = new_xyz_grid_centers + out_offset;
+    int *cur_new_xyz_grid_idxs = new_xyz_grid_idxs + out_offset;
+    float *cur_new_xyz_grid_dist2 = new_xyz_grid_dist2 + out_offset;
+
+    const int *cur_start_len = start_len + pt_idx * 2;
+    const int *cur_stack_neighbor_idxs =
+        stack_neighbor_idxs + cur_start_len[0];
+    const int neighbor_length = cur_start_len[1];
+
+    const float center_x = cur_new_xyz_grid_centers[0];
+    const float center_y = cur_new_xyz_grid_centers[1];
+    const float center_z = cur_new_xyz_grid_centers[2];
+
+    // The 1e40 sentinel does not fit in a float, hence double accumulators.
+    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+    int besti1 = -1, besti2 = -1, besti3 = -1;
+    for (int k = 0; k < neighbor_length; k++) {
+      const int cur_neighbor_idx = cur_stack_neighbor_idxs[k];
+
+      const float x = support_xyz[cur_neighbor_idx * 3 + 0];
+      const float y = support_xyz[cur_neighbor_idx * 3 + 1];
+      const float z = support_xyz[cur_neighbor_idx * 3 + 2];
+
+      const float d = (center_x - x) * (center_x - x) +
+                      (center_y - y) * (center_y - y) +
+                      (center_z - z) * (center_z - z);
+
+      if (d < best1) {
+        best3 = best2;
+        besti3 = besti2;
+        best2 = best1;
+        besti2 = besti1;
+        best1 = d;
+        besti1 = cur_neighbor_idx;
+      } else if (d < best2) {
+        best3 = best2;
+        besti3 = besti2;
+        best2 = d;
+        besti2 = cur_neighbor_idx;
+      } else if (d < best3) {
+        best3 = d;
+        besti3 = cur_neighbor_idx;
+      }
+    }
+    // With fewer than three candidates, repeat the nearest one; with none,
+    // all three indices stay -1.
+    if (besti2 == -1) {
+      besti2 = besti1;
+      best2 = best1;
+    }
+    if (besti3 == -1) {
+      besti3 = besti1;
+      best3 = best1;
+    }
+    cur_new_xyz_grid_dist2[0] = best1;
+    cur_new_xyz_grid_dist2[1] = best2;
+    cur_new_xyz_grid_dist2[2] = best3;
+    cur_new_xyz_grid_idxs[0] = besti1;
+    cur_new_xyz_grid_idxs[1] = besti2;
+    cur_new_xyz_grid_idxs[2] = besti3;
+  }
+}
+#endif
diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
index e558634068..60866500e2 100644
--- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
+++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
@@ -1867,3 +1867,60 @@ REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda);
 REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda);
 REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA,
                      prroi_pool_coor_backward_cuda);
+
+void StackQueryLocalNeighborIdxsCUDAKernelLauncher(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type);
+
+void stack_query_local_neighbor_idxs_cuda(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type) {
+  StackQueryLocalNeighborIdxsCUDAKernelLauncher(
+      support_xyz_tensor, xyz_batch_cnt_tensor, new_xyz_tensor,
+      new_xyz_batch_cnt_tensor, stack_neighbor_idxs_tensor, start_len_tensor,
+      cumsum_tensor, avg_length_of_neighbor_idxs, max_neighbour_distance,
+      nsample, neighbor_type);
+}
+
+void stack_query_local_neighbor_idxs_impl(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type);
+
+void StackQueryThreeNNLocalIdxsCUDAKernelLauncher(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids);
+
+void stack_query_three_nn_local_idxs_cuda(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids) {
+  StackQueryThreeNNLocalIdxsCUDAKernelLauncher(
+      support_xyz_tensor, new_xyz_tensor, new_xyz_grid_centers_tensor,
+      new_xyz_grid_idxs_tensor, new_xyz_grid_dist2_tensor,
+      stack_neighbor_idxs_tensor, start_len_tensor, M, num_total_grids);
+}
+
+void stack_query_three_nn_local_idxs_impl(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids);
+REGISTER_DEVICE_IMPL(stack_query_three_nn_local_idxs_impl, CUDA,
+                     stack_query_three_nn_local_idxs_cuda);
+REGISTER_DEVICE_IMPL(stack_query_local_neighbor_idxs_impl, CUDA,
+                     stack_query_local_neighbor_idxs_cuda);
diff --git a/mmcv/ops/csrc/pytorch/cuda/three_nn_vector_pool.cu b/mmcv/ops/csrc/pytorch/cuda/three_nn_vector_pool.cu
new file mode 100644
index 0000000000..4e904fcc24
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/cuda/three_nn_vector_pool.cu
@@ -0,0 +1,60 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "pytorch_cuda_helper.hpp"
+#include "three_nn_vector_pool_cuda_kernel.cuh"
+#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+
+// Launches the neighbor-gathering kernel with one thread per query point.
+void StackQueryLocalNeighborIdxsCUDAKernelLauncher(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type) {
+  int batch_size = xyz_batch_cnt_tensor.size(0);
+  int M = new_xyz_tensor.size(0);
+  at::cuda::CUDAGuard device_guard(support_xyz_tensor.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));
+  dim3 threads(THREADS_PER_BLOCK);
+
+  query_stacked_local_neighbor_idxs_cuda_kernel<<<blocks, threads, 0, stream>>>(
+      support_xyz_tensor.data_ptr<float>(),
+      xyz_batch_cnt_tensor.data_ptr<int>(), new_xyz_tensor.data_ptr<float>(),
+      new_xyz_batch_cnt_tensor.data_ptr<int>(),
+      stack_neighbor_idxs_tensor.data_ptr<int>(),
+      start_len_tensor.data_ptr<int>(), cumsum_tensor.data_ptr<int>(),
+      avg_length_of_neighbor_idxs, max_neighbour_distance, batch_size, M,
+      nsample, neighbor_type);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+// Launches the three-NN kernel on a (query point, grid) 2D launch grid.
+void StackQueryThreeNNLocalIdxsCUDAKernelLauncher(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids) {
+  at::cuda::CUDAGuard device_guard(support_xyz_tensor.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), num_total_grids);
+  dim3 threads(THREADS_PER_BLOCK);
+
+  query_three_nn_by_stacked_local_idxs_cuda_kernel<<<blocks, threads, 0,
+                                                     stream>>>(
+      support_xyz_tensor.data_ptr<float>(), new_xyz_tensor.data_ptr<float>(),
+      new_xyz_grid_centers_tensor.data_ptr<float>(),
+      new_xyz_grid_idxs_tensor.data_ptr<int>(),
+      new_xyz_grid_dist2_tensor.data_ptr<float>(),
+      stack_neighbor_idxs_tensor.data_ptr<int>(),
+      start_len_tensor.data_ptr<int>(), M, num_total_grids);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 4947b72152..198ada757b 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -446,6 +446,20 @@ Tensor nms_quadri(const Tensor dets, const Tensor scores, const Tensor order,
                   const Tensor dets_sorted, const float iou_threshold,
                   const int multi_label);
 
+void stack_query_local_neighbor_idxs(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type);
+
+void stack_query_three_nn_local_idxs(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
         py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -899,4 +913,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("dets"), py::arg("scores"), py::arg("order"),
         py::arg("dets_sorted"), py::arg("iou_threshold"),
         py::arg("multi_label"));
+  m.def("stack_query_local_neighbor_idxs", &stack_query_local_neighbor_idxs,
+        "stack quert local neighbor indexes", py::arg("support_xyz_tensor"),
+        py::arg("xyz_batch_cnt_tensor"), py::arg("new_xyz_tensor"),
+        py::arg("new_xyz_batch_cnt_tensor"),
+        py::arg("stack_neighbor_idxs_tensor"), py::arg("start_len_tensor"),
+        py::arg("cumsum_tensor"), py::arg("avg_length_of_neighbor_idxs"),
+        py::arg("max_neighbour_distance"), py::arg("nsample"),
+        py::arg("neighbor_type"));
+  m.def("stack_query_three_nn_local_idxs", &stack_query_three_nn_local_idxs,
+        "stack quert three nn local indexes", py::arg("support_xyz_tensor"),
+        py::arg("new_xyz_tensor"), py::arg("new_xyz_grid_centers_tensor"),
+        py::arg("new_xyz_grid_idxs_tensor"),
+        py::arg("new_xyz_grid_dist2_tensor"),
+        py::arg("stack_neighbor_idxs_tensor"), py::arg("start_len_tensor"),
+        py::arg("M"), py::arg("num_total_grids"));
 }
diff --git a/mmcv/ops/csrc/pytorch/three_nn_vector_pool.cpp b/mmcv/ops/csrc/pytorch/three_nn_vector_pool.cpp
new file mode 100644
index 0000000000..2e3483bd57
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/three_nn_vector_pool.cpp
@@ -0,0 +1,55 @@
+#include "pytorch_cpp_helper.hpp"
+#include "pytorch_device_registry.hpp"
+
+
+void stack_query_local_neighbor_idxs_impl(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type) {
+  DISPATCH_DEVICE_IMPL(stack_query_local_neighbor_idxs_impl, support_xyz_tensor,
+                       xyz_batch_cnt_tensor, new_xyz_tensor,
+                       new_xyz_batch_cnt_tensor, stack_neighbor_idxs_tensor,
+                       start_len_tensor, cumsum_tensor,
+                       avg_length_of_neighbor_idxs, max_neighbour_distance,
+                       nsample, neighbor_type);
+}
+
+void stack_query_local_neighbor_idxs(
+    const Tensor support_xyz_tensor, const Tensor xyz_batch_cnt_tensor,
+    const Tensor new_xyz_tensor, const Tensor new_xyz_batch_cnt_tensor,
+    Tensor stack_neighbor_idxs_tensor, Tensor start_len_tensor,
+    Tensor cumsum_tensor, const int avg_length_of_neighbor_idxs,
+    const float max_neighbour_distance, const int nsample,
+    const int neighbor_type) {
+  stack_query_local_neighbor_idxs_impl(
+      support_xyz_tensor, xyz_batch_cnt_tensor, new_xyz_tensor,
+      new_xyz_batch_cnt_tensor, stack_neighbor_idxs_tensor, start_len_tensor,
+      cumsum_tensor, avg_length_of_neighbor_idxs, max_neighbour_distance,
+      nsample, neighbor_type);
+}
+
+void stack_query_three_nn_local_idxs_impl(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids) {
+  DISPATCH_DEVICE_IMPL(stack_query_three_nn_local_idxs_impl, support_xyz_tensor,
+                       new_xyz_tensor, new_xyz_grid_centers_tensor,
+                       new_xyz_grid_idxs_tensor, new_xyz_grid_dist2_tensor,
+                       stack_neighbor_idxs_tensor, start_len_tensor, M,
+                       num_total_grids);
+}
+
+void stack_query_three_nn_local_idxs(
+    const Tensor support_xyz_tensor, const Tensor new_xyz_tensor,
+    const Tensor new_xyz_grid_centers_tensor, Tensor new_xyz_grid_idxs_tensor,
+    Tensor new_xyz_grid_dist2_tensor, Tensor stack_neighbor_idxs_tensor,
+    Tensor start_len_tensor, const int M, const int num_total_grids) {
+  stack_query_three_nn_local_idxs_impl(
+      support_xyz_tensor, new_xyz_tensor, new_xyz_grid_centers_tensor,
+      new_xyz_grid_idxs_tensor, new_xyz_grid_dist2_tensor,
+      stack_neighbor_idxs_tensor, start_len_tensor, M, num_total_grids);
+}
diff --git a/mmcv/ops/three_nn_vector_pool.py b/mmcv/ops/three_nn_vector_pool.py
new file mode 100644
index 0000000000..39a5d218fd
--- /dev/null
+++ b/mmcv/ops/three_nn_vector_pool.py
@@ -0,0 +1,116 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext',
+    ['stack_query_local_neighbor_idxs', 'stack_query_three_nn_local_idxs'])
+
+
+class ThreeNNVectorPoolByTwoStep(Function):
+    """The local space around a center point is divided into dense voxels,
+    where the inside point-wise features are generated by interpolating from
+    three nearest neighbors."""
+
+    @staticmethod
+    def forward(
+        ctx, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
+        new_xyz: torch.Tensor, new_xyz_grid_centers: torch.Tensor,
+        new_xyz_batch_cnt: torch.Tensor, max_neighbour_distance: float,
+        nsample: int, neighbor_type: int, avg_length_of_neighbor_idxs: int,
+        num_total_grids: int, neighbor_distance_multiplier: float
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Find the three nearest neighbors of every grid center.
+
+        Neighbors of each center are first gathered into one stacked buffer;
+        the three nearest ones are then picked per grid from that buffer.
+
+        Args:
+            xyz (torch.Tensor): XYZ coordinates of the features shape
+                with (N1 + N2 ..., 3).
+            xyz_batch_cnt: (torch.Tensor): Stacked input xyz coordinates
+                nums in each batch, just like (N1, N2, ...).
+            new_xyz (torch.Tensor): Centers of the ball
+                query shape with (M1 + M2 ..., 3).
+            new_xyz_grid_centers (torch.Tensor): Grids centers of each grid
+                shape with (M1 + M2 ..., num_total_grids, 3).
+            new_xyz_batch_cnt: (torch.Tensor): Stacked centers coordinates
+                nums in each batch, just like (M1, M2, ...).
+            max_neighbour_distance (float): Max neighbour distance for center.
+            nsample (int): Find all (-1), find limited number(>0).
+            neighbor_type (int): Neighbor type, 1: ball, others: cube.
+            avg_length_of_neighbor_idxs (int): Num avg length of neighbor idxs.
+            num_total_grids (int): Total grids num. The CUDA kernel uses it
+                as the stride of the grid dimension, so it has to equal
+                ``new_xyz_grid_centers.shape[1]``.
+            neighbor_distance_multiplier (float): Used to compute
+                query_distance. query_distance = neighbor_distance_multiplier
+                * max_neighbour_distance
+
+        Returns:
+            - new_xyz_grid_dist (torch.Tensor): Distances from each grid
+                center to its three nearest neighbors, with shape
+                (M1 + M2 ..., num_total_grids, 3).
+            - new_xyz_grid_idxs (torch.Tensor): Indexes for new xyz grids
+                with shape (M1 + M2 ..., num_total_grids, 3).
+            - avg_length_of_neighbor_idxs (torch.Tensor): Average length of
+                neighbor indexes.
+        """
+        num_new_xyz = new_xyz.shape[0]
+        new_xyz_grid_dist2 = new_xyz_grid_centers.new_zeros(
+            new_xyz_grid_centers.shape)
+        new_xyz_grid_idxs = new_xyz_grid_centers.new_full(
+            new_xyz_grid_centers.shape, -1, dtype=torch.int32)
+
+        if num_new_xyz == 0:
+            # No centers: skip the empty kernel launch and the division by
+            # ``num_new_xyz`` below.
+            return (torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs,
+                    torch.tensor(avg_length_of_neighbor_idxs))
+
+        # Both extension calls read raw data pointers, so hand them dense
+        # tensors.
+        xyz = xyz.contiguous()
+        xyz_batch_cnt = xyz_batch_cnt.contiguous()
+        new_xyz = new_xyz.contiguous()
+        new_xyz_batch_cnt = new_xyz_batch_cnt.contiguous()
+        new_xyz_grid_centers = new_xyz_grid_centers.contiguous()
+        query_distance = max_neighbour_distance * neighbor_distance_multiplier
+
+        while True:
+            # Capacity guess for the stacked buffer; when the kernel reports
+            # that more room was needed, retry with the measured average.
+            num_max_sum_points = avg_length_of_neighbor_idxs * num_new_xyz
+            stack_neighbor_idxs = new_xyz_grid_idxs.new_zeros(
+                num_max_sum_points)
+            start_len = new_xyz_grid_idxs.new_zeros(num_new_xyz, 2)
+            cumsum = new_xyz_grid_idxs.new_zeros(1)
+
+            ext_module.stack_query_local_neighbor_idxs(
+                xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt,
+                stack_neighbor_idxs, start_len, cumsum,
+                avg_length_of_neighbor_idxs, query_distance, nsample,
+                neighbor_type)
+
+            # A single device-to-host read of the counter per iteration.
+            num_neighbors = int(cumsum.item())
+            quotient, remainder = divmod(num_neighbors, num_new_xyz)
+            avg_length_of_neighbor_idxs = quotient + int(remainder > 0)
+
+            if num_neighbors <= num_max_sum_points:
+                break
+
+        stack_neighbor_idxs = stack_neighbor_idxs[:num_neighbors]
+        ext_module.stack_query_three_nn_local_idxs(
+            xyz, new_xyz, new_xyz_grid_centers, new_xyz_grid_idxs,
+            new_xyz_grid_dist2, stack_neighbor_idxs, start_len, num_new_xyz,
+            num_total_grids)
+
+        return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, torch.tensor(
+            avg_length_of_neighbor_idxs)
+
+
+three_nn_vector_pool_by_two_step = ThreeNNVectorPoolByTwoStep.apply
diff --git a/tests/test_ops/test_three_nn_vector_pool.py b/tests/test_ops/test_three_nn_vector_pool.py
new file mode 100644
index 0000000000..ef8c235275
--- /dev/null
+++ b/tests/test_ops/test_three_nn_vector_pool.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmcv.ops import three_nn_vector_pool_by_two_step
+from mmcv.utils import IS_CUDA_AVAILABLE
+
+
+def get_dense_voxels_by_center(point_centers, max_neighbour_distance,
+                               num_voxels):
+    """
+    Args:
+        point_centers: (N, 3)
+        max_neighbour_distance: float
+        num_voxels: [num_x, num_y, num_z]
+
+    Returns:
+        voxel_centers: (N, total_voxels, 3)
+    """
+    R = max_neighbour_distance
+    device = point_centers.device
+    x_grids = torch.arange(
+        -R + R / num_voxels[0],
+        R - R / num_voxels[0] + 1e-5,
+        2 * R / num_voxels[0],
+        device=device)
+    y_grids = torch.arange(
+        -R + R / num_voxels[1],
+        R - R / num_voxels[1] + 1e-5,
+        2 * R / num_voxels[1],
+        device=device)
+    z_grids = torch.arange(
+        -R + R / num_voxels[2],
+        R - R / num_voxels[2] + 1e-5,
+        2 * R / num_voxels[2],
+        device=device)
+    x_offset, y_offset, z_offset = torch.meshgrid(
+        x_grids, y_grids, z_grids)  # shape: [num_x, num_y, num_z]
+    xyz_offset = torch.cat(
+        (x_offset.contiguous().view(-1, 1), y_offset.contiguous().view(
+            -1, 1), z_offset.contiguous().view(-1, 1)),
+        dim=-1)
+    voxel_centers = point_centers[:, None, :] + xyz_offset[None, :, :]
+    return voxel_centers
+
+
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+])
+def test_three_nn_vector_pool(device):
+    xyz = torch.tensor(
+        [[0.7911, 4.1821, 18.1309], [9.8552, 19.9272, 7.4532],
+         [17.0715, 9.8851, 5.8078], [4.3750, 1.1232, 18.0702],
+         [14.0227, 9.5781, 15.7914], [3.0038, 8.7471, 12.6253],
+         [17.1353, 13.0427, 13.4723], [1.4284, 12.0409, 16.0280],
+         [10.5802, 11.9821, 10.6400], [11.2924, 16.3918, 16.3261],
+         [8.6749, 4.3318, 19.6607], [6.7047, 10.6616, 16.7599],
+         [15.1153, 1.8694, 16.1620], [4.5372, 2.2882, 12.4915],
+         [12.0136, 0.5850, 4.2164], [15.2224, 13.8230, 19.8346],
+         [16.7076, 12.8573, 5.8789], [17.8641, 18.0247, 0.7161],
+         [12.7604, 10.6771, 19.1813], [10.3219, 10.4839, 14.7624]],
+        device=device)
+    new_xyz = torch.tensor(
+        [[0.1411, 15.6141, 9.3022], [15.6595, 0.9505, 19.3470],
+         [8.0824, 10.3586, 17.3501], [7.3926, 9.9670, 6.6586],
+         [13.8781, 8.9048, 5.8824], [11.1121, 0.0274, 9.4883],
+         [0.4287, 1.5586, 6.9646], [2.7858, 1.8852, 15.0609],
+         [6.0411, 2.8716, 18.9102], [9.1480, 10.8151, 17.0509],
+         [5.1243, 8.9133, 18.5356], [19.7255, 14.6383, 9.3120]],
+        device=device)
+    expected_output = torch.tensor(
+        [[[8.9668, 9.9123, 8.9668], [5.8039, 6.9644, 5.8039],
+          [11.5133, 12.0966, 11.5133], [8.2455, 10.6970, 8.2455],
+          [7.2748, 9.2679, 7.2748], [4.5848, 4.6181, 4.5848],
+          [10.9016, 10.9635, 10.9016], [7.4585, 9.3244, 7.4585]],
+         [[15.0939, 17.9669, 15.0939], [17.0984, 18.8328, 17.0984],
+          [12.3685, 14.7086, 12.3685], [14.7482, 15.7545, 14.7482],
+          [18.6902, 21.4343, 18.6902], [20.3433, 22.1651, 20.3433],
+          [16.5673, 18.7873, 16.5673], [18.4120, 19.6169, 18.4120]],
+         [[3.6334, 5.9936, 3.6334], [6.9725, 7.6524, 6.9725],
+          [4.4467, 5.3546, 4.4467], [5.6979, 8.6041, 5.6979],
+          [7.8712, 9.9901, 7.8712], [10.3592, 10.6063, 10.3592],
+          [8.7992, 9.1461, 8.7992], [9.8155, 11.0808, 9.8155]],
+         [[8.6804, 13.0858, 8.6804], [4.2508, 9.0162, 4.2508],
+          [9.3306, 12.3016, 9.3306], [11.6039, 11.6039, 11.6039],
+          [12.4476, 12.4476, 12.4476], [8.2018, 8.2018, 8.2018],
+          [12.5967, 12.5967, 12.5967], [8.4263, 8.4263, 8.4263]],
+         [[12.9395, 12.9395, 12.9395], [8.5032, 8.5032, 8.5032],
+          [12.6873, 12.6873, 12.6873], [8.1143, 8.1143, 8.1143],
+          [12.8858, 12.8858, 12.8858], [8.4212, 8.4212, 8.4212],
+          [12.6325, 12.6325, 12.6325], [8.0283, 8.0283, 8.0283]],
+         [[15.7088, 15.7088, 15.7088], [13.6476, 13.6476, 13.6476],
+          [12.4531, 12.4531, 12.4531], [9.7247, 9.7247, 9.7247],
+          [14.7927, 14.7927, 14.7927], [12.5823, 12.5823, 12.5823],
+          [11.2755, 11.2755, 11.2755], [8.1626, 8.1626, 8.1626]],
+         [[22.1453, 22.1453, 22.1453], [20.1414, 20.1414, 20.1414],
+          [20.3329, 20.3329, 20.3329], [18.1298, 18.1298, 18.1298],
+          [18.9714, 18.9714, 18.9714], [16.5884, 16.5884, 16.5884],
+          [9.3774, 13.7203, 13.8864], [5.7959, 9.0028, 9.2854]],
+         [[6.9179, 7.2212, 9.6249], [4.3552, 4.7617, 10.7712],
+          [5.1734, 5.4860, 7.4278], [0.7898, 5.1267, 7.0813],
+          [5.7097, 8.4436, 9.5155], [1.9266, 6.4671, 10.6737],
+          [4.9670, 6.3179, 7.0175], [3.3207, 4.4467, 6.9320]],
+         [[1.8430, 4.9515, 9.1642], [3.3854, 5.6567, 12.0132],
+          [3.4549, 4.4924, 5.2514], [4.4065, 5.3146, 8.8669],
+          [4.4036, 8.6555, 10.6368], [5.2398, 9.0773, 12.0225],
+          [6.0146, 7.0864, 7.5323], [6.6513, 8.3557, 8.9531]],
+         [[4.2699, 6.5834, 7.4548], [7.2911, 7.4260, 7.7916],
+          [5.6190, 6.1713, 8.2128], [6.4337, 8.9186, 8.9761],
+          [2.9627, 7.3499, 8.7873], [4.5682, 9.4009, 10.3213],
+          [4.5445, 5.7129, 9.8524], [5.7222, 8.1848, 10.7472]],
+         [[3.6267, 4.1702, 5.6785], [7.5045, 8.6098, 12.7871],
+          [1.4901, 4.3573, 11.4361], [5.1277, 8.7020, 12.5351],
+          [6.1439, 7.1931, 8.2296], [8.8366, 9.5812, 9.7204],
+          [6.1401, 6.2724, 6.7349], [7.8596, 8.4678, 9.8021]],
+         [[9.8402, 15.8095, 18.3264], [5.8844, 14.7694, 16.4738],
+          [12.0585, 17.5071, 18.9946], [9.1217, 16.5738, 17.2139],
+          [12.3116, 20.2601, 22.6166], [9.4538, 19.4592, 21.1432],
+          [14.1476, 21.6108, 23.1613], [11.7453, 20.8619, 21.7249]]],
+        device=device)
+    xyz_batch_cnt = torch.tensor([8, 12], device=device).int()
+    new_xyz_batch_cnt = torch.tensor([4, 8], device=device).int()
+    max_neighbour_distance = 4.8
+    new_xyz_grid_centers = get_dense_voxels_by_center(new_xyz,
+                                                      max_neighbour_distance,
+                                                      (2, 2, 2))
+    nsample = -1
+    neighbor_type = 0
+    avg_length_of_neighbor_idxs = 1000
+    # NOTE(review): a (2, 2, 2) layout yields 8 grid centers per point (the
+    # expected output also has 8 rows per point), yet 27 is passed here and
+    # the kernel strides by this value. Confirm the intended value; changing
+    # it alters the results, so expected_output would need regenerating.
+    num_total_grids = 27
+    neighbor_distance_multiplier = 2.0
+    dist, _, _ = three_nn_vector_pool_by_two_step(
+        xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
+        max_neighbour_distance, nsample, neighbor_type,
+        avg_length_of_neighbor_idxs, num_total_grids,
+        neighbor_distance_multiplier)
+    assert torch.allclose(dist, expected_output, 1e-4)