Skip to content

Commit

Permalink
Merge pull request #3378 from cudawarped:replace_texture_ref_with_texture_obj
Browse files Browse the repository at this point in the history

Fix CUDA texture bugs and replace all instances of CUDA texture references with texture objects
  • Loading branch information
asmorkalov authored Dec 20, 2022
2 parents b5f4e24 + 8a6ea82 commit 8db3e62
Show file tree
Hide file tree
Showing 33 changed files with 1,135 additions and 2,329 deletions.
84 changes: 13 additions & 71 deletions modules/cudaarithm/src/cuda/lut.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,14 @@
#include "opencv2/cudaarithm.hpp"
#include "opencv2/cudev.hpp"
#include "opencv2/core/private.cuda.hpp"
#include <opencv2/cudev/ptr2d/texture.hpp>

using namespace cv;
using namespace cv::cuda;
using namespace cv::cudev;

namespace cv { namespace cuda {

texture<uchar, cudaTextureType1D, cudaReadModeElementType> texLutTable;

LookUpTableImpl::LookUpTableImpl(InputArray _lut)
{
if (_lut.kind() == _InputArray::CUDA_GPU_MAT)
Expand All @@ -73,83 +72,28 @@ namespace cv { namespace cuda {
Mat h_lut = _lut.getMat();
d_lut.upload(Mat(1, 256, h_lut.type(), h_lut.data));
}

CV_Assert( d_lut.depth() == CV_8U );
CV_Assert( d_lut.rows == 1 && d_lut.cols == 256 );

cc30 = deviceSupports(FEATURE_SET_COMPUTE_30);

if (cc30)
{
// Use the texture object
cudaResourceDesc texRes;
std::memset(&texRes, 0, sizeof(texRes));
texRes.resType = cudaResourceTypeLinear;
texRes.res.linear.devPtr = d_lut.data;
texRes.res.linear.desc = cudaCreateChannelDesc<uchar>();
texRes.res.linear.sizeInBytes = 256 * d_lut.channels() * sizeof(uchar);

cudaTextureDesc texDescr;
std::memset(&texDescr, 0, sizeof(texDescr));

CV_CUDEV_SAFE_CALL( cudaCreateTextureObject(&texLutTableObj, &texRes, &texDescr, 0) );
}
else
{
// Use the texture reference
cudaChannelFormatDesc desc = cudaCreateChannelDesc<uchar>();
CV_CUDEV_SAFE_CALL( cudaBindTexture(0, &texLutTable, d_lut.data, &desc) );
}
}

// Releases the texture resource used for the lookup table, choosing the
// release call from the cc30 flag.
// NOTE(review): this listing is a diff rendered without +/- markers. The
// destructor body presumably belongs to the removed code, while the trailing
// `szInBytes = ...` assignment and closing brace presumably form the new tail
// of the constructor -- confirm against the repository.
LookUpTableImpl::~LookUpTableImpl()
{
if (cc30)
{
// Texture-object path: destroy the object; the returned status is not checked here.
cudaDestroyTextureObject(texLutTableObj);
}
else
{
// Texture-reference path: unbind the file-scope texLutTable reference;
// the returned status is not checked here.
cudaUnbindTexture(texLutTable);
}
// Table size in bytes: 256 entries * channel count * sizeof(uchar).
szInBytes = 256 * d_lut.channels() * sizeof(uchar);
}

// Lookup-table functor for single-channel uchar data: returns the table entry
// stored at index x.
// NOTE(review): this listing is a diff rendered without +/- markers, so two
// versions of the data member and of operator() appear back to back and the
// braces do not balance. Presumably the `cudaTextureObject_t` member with the
// `#if CV_CUDEV_ARCH < 300` body is the removed code and the
// `TexturePtr<uchar>` member with `tex(x)` is its replacement -- confirm
// against the repository before compiling.
struct LutTablePtrC1
{
typedef uchar value_type;
typedef uchar index_type;

// Texture object handle read by the tex1Dfetch<uchar> branch below.
cudaTextureObject_t texLutTableObj;

// The first argument is unnamed and unused; x is the index fetched from the table.
__device__ __forceinline__ uchar operator ()(uchar, uchar x) const
{
#if CV_CUDEV_ARCH < 300
// Fetch through the file-scope texture reference texLutTable.
return tex1Dfetch(texLutTable, x);
#else
// Fetch through the texture object held by this functor.
return tex1Dfetch<uchar>(texLutTableObj, x);
#endif
// Texture accessor for the table; invoked as tex(index).
cv::cudev::TexturePtr<uchar> tex;
// Same lookup expressed through the TexturePtr accessor.
__device__ __forceinline__ uchar operator ()(uchar, uchar x) const {
return tex(x);
}
};

// Lookup-table functor for three-channel uchar3 data. Each channel value v of
// the input is looked up at table index v * 3 + c, where c is 0 for .x, 1 for
// .y and 2 for .z, and the three results are packed into a uchar3.
// NOTE(review): this listing is a diff rendered without +/- markers, so two
// versions of the data member and of operator() appear back to back and the
// braces do not balance. Presumably the `cudaTextureObject_t` member with the
// `#if CV_CUDEV_ARCH < 300` body is the removed code and the
// `TexturePtr<uchar>` member is its replacement -- confirm against the
// repository before compiling.
struct LutTablePtrC3
{
typedef uchar3 value_type;
typedef uchar3 index_type;

// Texture object handle read by the tex1Dfetch<uchar> branch below.
cudaTextureObject_t texLutTableObj;

// The first argument is unnamed and unused; x carries the three per-channel indices.
__device__ __forceinline__ uchar3 operator ()(const uchar3&, const uchar3& x) const
{
#if CV_CUDEV_ARCH < 300
// Fetch each channel through the file-scope texture reference texLutTable.
return make_uchar3(tex1Dfetch(texLutTable, x.x * 3), tex1Dfetch(texLutTable, x.y * 3 + 1), tex1Dfetch(texLutTable, x.z * 3 + 2));
#else
// Fetch each channel through the texture object held by this functor.
return make_uchar3(tex1Dfetch<uchar>(texLutTableObj, x.x * 3), tex1Dfetch<uchar>(texLutTableObj, x.y * 3 + 1), tex1Dfetch<uchar>(texLutTableObj, x.z * 3 + 2));
#endif
// Texture accessor for the table; invoked as tex(index).
cv::cudev::TexturePtr<uchar> tex;
// Same three lookups expressed through the TexturePtr accessor.
__device__ __forceinline__ uchar3 operator ()(const uchar3&, const uchar3& x) const {
return make_uchar3(tex(x.x * 3), tex(x.y * 3 + 1), tex(x.z * 3 + 2));
}
};

Expand All @@ -169,20 +113,18 @@ namespace cv { namespace cuda {
{
GpuMat_<uchar> src1(src.reshape(1));
GpuMat_<uchar> dst1(dst.reshape(1));

cv::cudev::Texture<uchar> tex(szInBytes, reinterpret_cast<uchar*>(d_lut.data));
LutTablePtrC1 tbl;
tbl.texLutTableObj = texLutTableObj;

tbl.tex = TexturePtr<uchar>(tex);
dst1.assign(lut_(src1, tbl), stream);
}
else if (lut_cn == 3)
{
GpuMat_<uchar3>& src3 = (GpuMat_<uchar3>&) src;
GpuMat_<uchar3>& dst3 = (GpuMat_<uchar3>&) dst;

cv::cudev::Texture<uchar> tex(szInBytes, reinterpret_cast<uchar*>(d_lut.data));
LutTablePtrC3 tbl;
tbl.texLutTableObj = texLutTableObj;

tbl.tex = TexturePtr<uchar>(tex);
dst3.assign(lut_(src3, tbl), stream);
}

Expand Down
6 changes: 1 addition & 5 deletions modules/cudaarithm/src/lut.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@ class LookUpTableImpl : public LookUpTable
{
// NOTE(review): this listing is a diff rendered without +/- markers; the
// summary above reports 1 addition and 5 deletions for this header.
// Presumably the destructor, texLutTableObj and cc30 are the removed lines
// and szInBytes is the added one -- confirm against the repository.
public:
// Builds the implementation from the given lookup table; the .cu listing
// asserts the table is 8-bit with 1 row and 256 columns.
LookUpTableImpl(InputArray lut);
// Releases the texture resource selected by cc30 (see the .cu listing).
~LookUpTableImpl();

// Applies the lookup table to src and writes the result to dst; the work is
// issued on `stream`, which defaults to Stream::Null().
void transform(InputArray src, OutputArray dst, Stream& stream = Stream::Null()) CV_OVERRIDE;

private:
// Lookup table held as a GpuMat; its data pointer backs the texture.
GpuMat d_lut;
// Texture object created over d_lut.data with cudaCreateTextureObject.
cudaTextureObject_t texLutTableObj;
// Result of deviceSupports(FEATURE_SET_COMPUTE_30); selects the texture-object path.
bool cc30;
// Table size in bytes, set to 256 * d_lut.channels() * sizeof(uchar).
size_t szInBytes = 0;
};

} }
Expand Down
Loading

0 comments on commit 8db3e62

Please sign in to comment.