Exposing c++ PHL Implementation in Python API #1453

Merged
merged 15 commits on Jan 18, 2021

4 changes: 4 additions & 0 deletions docs/source/networks.rst
@@ -203,6 +203,10 @@ Layers
.. autoclass:: BilateralFilter
:members:

`PHLFilter`
~~~~~~~~~~~~~~~~~
.. autoclass:: PHLFilter

`SavitzkyGolayFilter`
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SavitzkyGolayFilter
1 change: 1 addition & 0 deletions monai/csrc/ext.cpp
@@ -21,6 +21,7 @@ limitations under the License.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// filtering
m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter");
m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter");

// lltm
m.def("lltm_forward", &lltm_forward, "LLTM forward");
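With this binding in place, the permutohedral filter can be reached from Python in the same way as bilateral_filter. A minimal usage sketch, assuming the compiled extension is importable as monai._C and using arbitrary example shapes:

    # Sketch only: the module path and tensor shapes are assumptions, not part of this change.
    import torch
    import monai._C as _C

    data = torch.rand(1, 3, 32, 32)           # (batch, channels, *spatial) tensor to filter
    features = torch.rand(1, 5, 32, 32)        # (batch, feature_channels, *spatial) guidance
    smoothed = _C.phl_filter(data, features)   # returns a tensor with the same shape as data
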
2 changes: 1 addition & 1 deletion monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp
@@ -158,7 +158,7 @@ torch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma,
// Preparing output tensor.
torch::Tensor outputTensor = torch::zeros_like(inputTensor);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.type(), "BilateralFilterCpu", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpu", ([&] {
BilateralFilterCpu<scalar_t>(
inputTensor, outputTensor, spatialSigma, colorSigma);
}));
7 changes: 3 additions & 4 deletions monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp
@@ -62,13 +62,12 @@ void BilateralFilterPHLCpu(
}

// Filtering data with respect to the features.
scalar_t* output =
PermutohedralCPU<scalar_t>(data, features, desc.channelCount, featureChannels, desc.channelStride);
PermutohedralCPU<scalar_t>(data, features, desc.channelCount, featureChannels, desc.channelStride);

// Writing output tensor.
for (int i = 0; i < desc.channelStride; i++) {
for (int c = 0; c < desc.channelCount; c++) {
outputTensorData[batchOffset + i + c * desc.channelStride] = output[i * desc.channelCount + c];
outputTensorData[batchOffset + i + c * desc.channelStride] = data[i * desc.channelCount + c];
}
}
}
@@ -81,7 +80,7 @@ void BilateralFilterPHLCpu(
torch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
torch::Tensor outputTensor = torch::zeros_like(inputTensor);

AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterPhlCpu", ([&] {
AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterPhlCpu", ([&] {
BilateralFilterPHLCpu<scalar_t>(inputTensor, outputTensor, spatialSigma, colorSigma);
}));

28 changes: 21 additions & 7 deletions monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
@@ -36,6 +36,9 @@ __global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) {
int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;
int batchOffset = blockIdx.y * cBatchStride;

if (homeOffset >= cColorStride)
return;

scalar_t weightSum = 0;

for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) {
@@ -79,6 +82,9 @@ __global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) {
int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;
int batchOffset = blockIdx.y * cBatchStride;

if (homeOffset >= cColorStride)
return;

int homeX = homeOffset / cStrides[0];
int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];

@@ -132,6 +138,9 @@ __global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) {
int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;
int batchOffset = blockIdx.y * cBatchStride;

if (homeOffset >= cColorStride)
return;

int homeX = homeOffset / cStrides[0];
int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];
int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2];
@@ -211,22 +220,27 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor,
cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize);
cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float));

#define BLOCK_SIZE 32

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
inputTensor.type(), "BilateralFilterCudaKernel", ([&] {
inputTensor.scalar_type(), "BilateralFilterCudaKernel", ([&] {
// Dispatch kernel. (Partial template function specialisation not supported at present so using this switch
// instead)
switch (D) {
case (1):
BilateralFilterCudaKernel1D<scalar_t, C><<<dim3(desc.channelStride, desc.batchCount), dim3(1, 1)>>>(
inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
BilateralFilterCudaKernel1D<scalar_t, C>
<<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
break;
case (2):
BilateralFilterCudaKernel2D<scalar_t, C><<<dim3(desc.channelStride, desc.batchCount), dim3(1, 1)>>>(
inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
BilateralFilterCudaKernel2D<scalar_t, C>
<<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
break;
case (3):
BilateralFilterCudaKernel3D<scalar_t, C><<<dim3(desc.channelStride, desc.batchCount), dim3(1, 1)>>>(
inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
BilateralFilterCudaKernel3D<scalar_t, C>
<<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
break;
}
}));
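The new launch configuration rounds the grid up to whole 32-thread blocks, which is why each kernel now begins with a bounds check. An illustrative calculation with made-up numbers:

    # Why `if (homeOffset >= cColorStride) return;` is needed (illustrative numbers only).
    BLOCK_SIZE = 32
    channel_stride = 1000                            # example element count per channel
    num_blocks = channel_stride // BLOCK_SIZE + 1    # 32 blocks, as in int(channelStride / BLOCK_SIZE) + 1
    threads = num_blocks * BLOCK_SIZE                # 1024 threads launched
    assert threads >= channel_stride                 # the 24 surplus threads must return early
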
17 changes: 14 additions & 3 deletions monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
@@ -30,6 +30,9 @@ __global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputDat
int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;
int batchIndex = blockIdx.y;

if (elementIndex >= cChannelStride)
return;

int dataBatchOffset = batchIndex * cBatchStride;
int featureBatchOffset = batchIndex * (D + C) * cChannelStride;

@@ -56,6 +59,10 @@ template <typename scalar_t, int C>
__global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) {
int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;
int batchIndex = blockIdx.y;

if (elementIndex >= cChannelStride)
return;

int batchOffset = batchIndex * cBatchStride;

#pragma unroll
@@ -95,9 +102,12 @@ void BilateralFilterPHLCuda(
cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float));
cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float));

#define BLOCK_SIZE 32

// Creating features
FeatureCreation<scalar_t, C, D>
<<<dim3(desc.channelStride, desc.batchCount), dim3(1, 1)>>>(inputTensorData, data, features);
<<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
inputTensorData, data, features);

// Filtering data with respect to the features for each sample in batch
for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) {
@@ -108,7 +118,8 @@ void BilateralFilterPHLCuda(
}

// Writing output
WriteOutput<scalar_t, C><<<dim3(desc.channelStride, desc.batchCount), dim3(1, 1)>>>(data, outputTensorData);
WriteOutput<scalar_t, C><<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(
data, outputTensorData);

cudaFree(data);
cudaFree(features);
@@ -119,7 +130,7 @@ torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSig
torch::Tensor outputTensor = torch::zeros_like(inputTensor);

#define CASE(c, d) \
AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterCudaPHL", ([&] { \
AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterCudaPHL", ([&] { \
BilateralFilterPHLCuda<scalar_t, c, d>( \
inputTensor, outputTensor, spatialSigma, colorSigma); \
}));
3 changes: 2 additions & 1 deletion monai/csrc/filtering/filtering.h
@@ -13,4 +13,5 @@ limitations under the License.

#pragma once

#include "bilateral/bilateral.h"
#include "bilateral/bilateral.h"
#include "permutohedral/permutohedral.h"
monai/csrc/filtering/permutohedral/hash_table.cu → monai/csrc/filtering/permutohedral/hash_table.cuh
@@ -90,9 +90,14 @@ static scalar_t* createHashTable(int capacity) {
template <typename scalar_t>
static void destroyHashTable() {
#ifndef LINEAR_D_MEMORY
cudaFree(table_keys);
signed short* keys;
cudaMemcpyFromSymbol(&keys, table_keys, sizeof(unsigned int*));
cudaFree(keys);
#endif
cudaFree(table_entries);

int* entries;
cudaMemcpyFromSymbol(&entries, table_entries, sizeof(int*));
cudaFree(entries);
}

template <int kd>
71 changes: 71 additions & 0 deletions monai/csrc/filtering/permutohedral/permutohedral.cpp
@@ -0,0 +1,71 @@
#include "utils/common_utils.h"
#include "utils/meta_macros.h"

#include "permutohedral.h"

torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {
input = input.contiguous();

int batchCount = input.size(0);
int batchStride = input.stride(0);
int elementCount = input.stride(1);
int channelCount = input.size(1);
int featureCount = features.size(1);

// movedim not supported in torch < 1.7.1
#if MONAI_TORCH_VERSION >= 10701
torch::Tensor data = input.clone().movedim(1, -1).contiguous();
features = features.movedim(1, -1).contiguous();
#else
torch::Tensor data = input.clone();
features = features;

for (int i = 1; i < input.dim() - 1; i++) {
data = data.transpose(i, i + 1);
features = features.transpose(i, i + 1);
}

data = data.contiguous();
features = features.contiguous();
#endif

#ifdef WITH_CUDA
if (torch::cuda::is_available() && data.is_cuda()) {
CHECK_CONTIGUOUS_CUDA(data);

#define CASE(dc, fc) \
AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \
for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \
scalar_t* offsetData = data.data_ptr<scalar_t>() + batchIndex * batchStride; \
scalar_t* offsetFeatures = \
features.data_ptr<scalar_t>() + batchIndex * fc * elementCount; \
PermutohedralCuda<scalar_t, dc, fc>(offsetData, offsetFeatures, elementCount, true); \
} \
}));
SWITCH_AB(CASE, 16, 19, channelCount, featureCount);

} else {
#endif
AT_DISPATCH_FLOATING_TYPES(
data.scalar_type(), "PermutohedralCPU", ([&] {
for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) {
scalar_t* offsetData = data.data_ptr<scalar_t>() + batchIndex * batchStride;
scalar_t* offsetFeatures = features.data_ptr<scalar_t>() + batchIndex * featureCount * elementCount;
PermutohedralCPU<scalar_t>(offsetData, offsetFeatures, channelCount, featureCount, elementCount);
}
}));
#ifdef WITH_CUDA
}
#endif

// movedim not supported in torch < 1.7.1
#if MONAI_TORCH_VERSION >= 10701
data = data.movedim(-1, 1);
#else
for (int i = input.dim() - 1; i > 1; i--) {
data = data.transpose(i - 1, i);
}
#endif

return data;
}
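
The #else branch above emulates movedim(1, -1) on PyTorch versions before 1.7.1 by bubbling the channel dimension to the end with successive transposes. A quick sketch of the equivalence, with arbitrary example shapes:

    # Sketch: the transpose fallback matches movedim for a (B, C, D, H, W) tensor.
    import torch

    x = torch.rand(2, 3, 8, 8, 8)
    a = x.movedim(1, -1)                  # channels-last, torch >= 1.7.1
    b = x
    for i in range(1, x.dim() - 1):       # mirrors: for (int i = 1; i < input.dim() - 1; i++)
        b = b.transpose(i, i + 1)         # move dim 1 one step toward the end
    assert torch.equal(a, b)              # identical values and shape
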
6 changes: 5 additions & 1 deletion monai/csrc/filtering/permutohedral/permutohedral.h
@@ -11,10 +11,14 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include <torch/extension.h>

#pragma once
template <typename scalar_t>
scalar_t* PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
#ifdef WITH_CUDA
template <typename scalar_t, int dc, int fc>
void PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate);
#endif

torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features);
26 changes: 6 additions & 20 deletions monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp
@@ -215,7 +215,7 @@ class PermutohedralLattice {
* im : image to be bilateral-filtered.
* ref : reference image whose edges are to be respected.
*/
static scalar_t* filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {
static void filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {
// Create lattice
PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount);

@@ -236,20 +236,16 @@ class PermutohedralLattice {
lattice.blur();

// Slice from the lattice
scalar_t* outputData = new scalar_t[elementCount * dataChannels];

lattice.beginSlice();

for (int i = 0, e = 0; e < elementCount; e++) {
lattice.slice(col);

scalar_t scale = 1.0f / col[dataChannels];
for (int c = 0; c < dataChannels; c++, i++) {
outputData[i] = col[c] * scale;
data[i] = col[c] * scale;
}
}

return outputData;
}

/* Constructor
@@ -498,19 +494,9 @@ };
};

template <typename scalar_t>
scalar_t* PermutohedralCPU(
scalar_t* data,
scalar_t* features,
int dataChannels,
int featureChannels,
int elementCount) {
return PermutohedralLattice<scalar_t>::filter(data, features, dataChannels, featureChannels, elementCount);
void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {
PermutohedralLattice<scalar_t>::filter(data, features, dataChannels, featureChannels, elementCount);
}

template float* PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount);
template double* PermutohedralCPU(
double* data,
double* features,
int dataChannels,
int featureChannels,
int elementCount);
template void PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount);
template void PermutohedralCPU(double* data, double* features, int dataChannels, int featureChannels, int elementCount);
2 changes: 1 addition & 1 deletion monai/csrc/filtering/permutohedral/permutohedral_cuda.cu
@@ -46,7 +46,7 @@ SOFTWARE.
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>

#include "hash_table.cu"
#include "hash_table.cuh"
#include "utils/meta_macros.h"

template <typename scalar_t>
2 changes: 1 addition & 1 deletion monai/networks/layers/__init__.py
@@ -11,7 +11,7 @@

from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
from .filtering import BilateralFilter
from .filtering import BilateralFilter, PHLFilter
from .simplelayers import (
LLTM,
ChannelPad,
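
With PHLFilter exported here, the layer becomes part of the public monai.networks.layers API. A hypothetical end-to-end sketch; the apply(...) call signature is assumed to mirror BilateralFilter and may differ from the final layer:

    # Hypothetical usage of the newly exported layer; the call signature is an assumption.
    import torch
    from monai.networks.layers import PHLFilter

    data = torch.rand(1, 1, 64, 64)          # image to smooth
    features = torch.rand(1, 4, 64, 64)      # guidance features (e.g. coordinates + intensity)
    output = PHLFilter.apply(data, features)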