
fix reduce_any kernel data race on sharedMem #47233

Merged
merged 6 commits into PaddlePaddle:develop from zhangbopd:fix_reduce_any_data_race
Oct 27, 2022

Conversation

zhangbopd
Contributor

@zhangbopd zhangbopd commented Oct 20, 2022

PR types

Bug fixes

PR changes

OPs

Describe

  • Analyzed and fixed the shared-memory data race reported in issue: Potential race-condition in phi::funcs::ReduceAnyKernel #46974.

    1. Root cause: the code reads shared memory with val = shared[bid * block_dim_x + lane];. The CudaShuffleDownSync that follows only synchronizes threads within a warp, so threads of different warps remain unsynchronized; the write shared[threadIdx.y] = val; that comes right after therefore races with the shared-memory reads above.
    2. Fix: insert a thread synchronization (block-wide barrier) before the write. A minimal sketch of the pattern and the fix follows this list.
    3. Additional note: replacing the synchronization with the following guard also removes the data race, but it performs worse than the barrier:
      if (wid % block_dim_x == 0) { val = shared[bid * block_dim_x + lane]; }
  • The kps changes touch many files and the CI job timed out. Benchmark numbers measured locally are shown in the image below; a few GPUs were occupied by other jobs during the run, so part of the data is inaccurate.
    [local benchmark results image]

  • Replacing division and modulo with bit operations gives a small additional performance gain.
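
A minimal sketch of the race and the fix described above, assuming warpSize == 32, a 2-D thread block, and the raw CUDA intrinsic __shfl_down_sync in place of Paddle's CudaShuffleDownSync wrapper; the function name and shared-memory layout are illustrative, not the actual phi::funcs::ReduceAnyKernel code:

    // Illustrative only: shows why a block-wide barrier is needed between the
    // shared-memory reads and the shared-memory write.
    __device__ float BlockReduceSketch(float* shared, int bid, int block_dim_x) {
      int lane = threadIdx.x;

      // Read a partial result that may have been written by another warp.
      float val = shared[bid * block_dim_x + lane];

      // Warp-level reduction: this only synchronizes lanes inside one warp;
      // threads belonging to different warps are still unordered here.
      for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffff, val, offset);
      }

      // The fix: without this barrier the write below can overlap with reads
      // still in flight in other warps, which is the reported data race.
      __syncthreads();

      // Write the warp's result back (after the shuffle loop lane 0 holds the
      // reduced value; the guard is added for this sketch).
      if (lane == 0) {
        shared[threadIdx.y] = val;
      }
      return val;
    }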

@paddle-bot

paddle-bot bot commented Oct 20, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int lane, tid, wid, n;
if (kWarpSize == 32 || kWarpSize == 64) {
Contributor


kWarpSize always satisfies this condition, doesn't it? Is this check still needed?

Contributor Author

@zhangbopd zhangbopd Oct 24, 2022


The branch with the original computation was kept in case warpSize changes in a future NVIDIA architecture; it will be removed later.

Contributor Author


done

@zhangbopd zhangbopd requested a review from ZzSean October 24, 2022 08:40
ZzSean
ZzSean previously approved these changes Oct 25, 2022
Contributor

@ZzSean ZzSean left a comment


LGTM for CI-OP-Benchmark

int lane, tid, wid, bid, n;
// Bit operation can be used when kWarpSize is 32 or 64 now
n = kWarpSize == 32 ? 5 : 6;
block_dim_x = blockDim.x >> n;
Contributor


For n = kWarpSize == 32 ? 5 : 6; there are two further possible improvements:

  1. Since n is the right-shift amount, it can be renamed to int rshift_val.
  2. The value can be determined at compile time, so it can be written as:
constexpr int rshift_val = (kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5;
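
For illustration, a minimal sketch of how the compile-time shift amount can replace the division and modulo, assuming kWarpSize is a compile-time power-of-two constant (32 or 64); this is a hypothetical fragment, not the exact kernel code:

    // Hypothetical fragment based on the reviewer's suggestion: the shift
    // amount is fixed at compile time, so / and % become bit operations.
    constexpr int rshift_val = (kWarpSize == 64) ? 6 : 5;  // 5 for 32, 6 for 64

    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int wid = tid >> rshift_val;           // equivalent to tid / kWarpSize
    int lane = tid & (kWarpSize - 1);      // equivalent to tid % kWarpSize
    int block_dim_x = blockDim.x >> rshift_val;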

Contributor Author


done

int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int bid = threadIdx.y;
int lane, tid, wid, bid, n;
Contributor


A small coding-style suggestion: separating declaration from assignment, and leaving variables uninitialized at declaration, hurts maintainability and debugging; please go back to the original style of initializing the variables where they are declared.

Contributor Author


The declarations and assignments were separated earlier because of variable-lifetime concerns inside the if branch; the branch has now been removed and the code was changed as suggested. done

Contributor

@ZzSean ZzSean left a comment


LGTM for CI-OP-Benchmark

@JamesLim-sy
Contributor

@niuliling123 please review the changes to KPS;
@zhangbopd please explain in the PR description why modifying the CUDA kernel fixes the problem reported in issue #46974;
@mingxu1067 please review whether the issue is effectively resolved.

Contributor

@JamesLim-sy JamesLim-sy left a comment


LGTM
Can be merged once all remaining reviewers approve.

Contributor

@AnnaTrainingG AnnaTrainingG left a comment


LGTM

@mingxu1067
Collaborator

Verified, no more errors reported from compute-sanitizer. LGTM

@JamesLim-sy JamesLim-sy merged commit 77dbb31 into PaddlePaddle:develop Oct 27, 2022
@zhangbopd zhangbopd deleted the fix_reduce_any_data_race branch October 27, 2022 05:38