
[Compile]Error compiling with TENSORRT on CUDA 12.2 #55016

Closed
engineer1109 opened this issue Jun 30, 2023 · 6 comments
Assignees
Labels
NVIDIA · PFCC (Paddle Framework Contributor Club, https://github.com/PaddlePaddle/community/tree/master/pfcc) · status/close (closed) · type/bug-report (bug report)

Comments

@engineer1109
Contributor

Describe the Bug

Current branch: develop, commit 12a296c
cmake .. -DWITH_CUSTOM_DEVICE=ON -DWITH_GPU=ON -DWITH_TENSORRT=ON

CUDA 12.2, installed via apt
TensorRT 8.6.1.6, installed via apt (libnvinfer-dev, libnvinfer-plugin-dev)

A large number of identical compile errors appear:

/media/wjl/D2/github/fork/Paddle/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_kernel.cu(83): error: no instance of function template "cuda::std::__4::plus<void>::operator()" matches the argument list
            argument types are: (cub::CUB_200101_750_NS::KeyValuePair<float, float>, cub::CUB_200101_750_NS::KeyValuePair<float, float>)
            object type is: cub::CUB_200101_750_NS::Sum
      threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
                   ^
          detected during:
            instantiation of "void paddle::inference::tensorrt::plugin::embLayerNormKernel_2<T,TPB>(int32_t, const int32_t *, const int32_t *, const float *, const float *, const T *, const T *, int32_t, int32_t, T *) [with T=float, TPB=256U]" at line 279
            instantiation of "int32_t paddle::inference::tensorrt::plugin::embSkipLayerNorm_2(cudaStream_t, int32_t, int32_t, int32_t, const int32_t *, const int32_t *, int32_t, const float *, const float *, const T *, const T *, int32_t, int32_t, T *) [with T=float]" at line 365

cub::Sum is the problem.

Additional Supplementary Information

No response

@paddle-bot paddle-bot bot added the PFCC (Paddle Framework Contributor Club, https://github.com/PaddlePaddle/community/tree/master/pfcc) label Jun 30, 2023
@ForFishes
Member

Hello, are you using the official image? Could you try compiling with the officially released Docker image? https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html

@engineer1109
Contributor Author

@ForFishes No image was used. The source of the problem is that CUDA 12.1 works fine while CUDA 12.2 does not.

@paddle-bot paddle-bot bot added the status/following-up label and removed the status/new-issue label Jul 4, 2023
@jeng1220
Collaborator

nvbugs 4202615

@jeng1220
Collaborator

jeng1220 commented Jul 20, 2023

cub::Sum is `__host__ __device__ __forceinline__ T cub::Sum::operator()`.
The quickest workaround is to replace the failing call with:
// threadData = pairSum(threadData, kvp(rldval, rldval * val));
threadData.key += rldval;
threadData.value += rldval * val;

For now, the problem appears to be caused by the new cub::Sum reusing ::cuda::std::plus<>:
https://github.com/NVIDIA/cub/blob/main/cub/thread/thread_operators.cuh#L79

The old cub implemented Sum itself:
https://github.com/NVIDIA/cub/blob/2.0.X/cub/thread/thread_operators.cuh#L97C1-L106C3

Since cub is a low-level primitive, warp_reduce, block_reduce, and device_reduce are also affected,
so fixing a single call site will not solve the problem.

@jeng1220
Collaborator

jeng1220 commented Jul 21, 2023

@engineer1109 ,
The problem should now be fixed. If everything looks good, please close this Issue.

cqulilujia pushed a commit to cqulilujia/Paddle that referenced this issue Jul 24, 2023
@jeng1220
Collaborator

@engineer1109 ,
Since the problem has been fixed, I am closing this Issue. If you still run into the problem, please reopen the Issue here.

@paddle-bot paddle-bot bot added the status/close label and removed the status/following-up label Jul 27, 2023
wz1qqx pushed a commit to wz1qqx/Paddle that referenced this issue Jul 31, 2023
jinjidejinmuyan pushed a commit to jinjidejinmuyan/Paddle that referenced this issue Aug 30, 2023

3 participants