Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] Support warp-level shuffle primitives with subgroup #17699

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ std::string CodeGenWebGPU::Finish() {
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
// TODO(Charlie): Add enable_subgroups_ to control
header_stream << "enable subgroups;\n\n";
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
}

Expand Down
56 changes: 56 additions & 0 deletions src/target/source/intrin_rule_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ namespace intrin {

using tir::FLowerIntrinsic;

// Warp-level primitives, modelled on the equivalent functor in intrin_rule_metal.cc.
// Translates a generic TVM warp-shuffle builtin into the WebGPU subgroup op that replaces it.
struct WebGPUWarpIntrinsic {
  // t       : data type of the call being lowered (not consulted when picking the op).
  // orig_op : tvm_warp_shuffle, tvm_warp_shuffle_up or tvm_warp_shuffle_down.
  // Returns the registered "tir.webgpu.subgroup_shuffle*" op for orig_op.
  const Op operator()(DataType t, const Op& orig_op) const {
    const char* target_name = nullptr;
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      target_name = "tir.webgpu.subgroup_shuffle";
    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      target_name = "tir.webgpu.subgroup_shuffle_up";
    } else {
      // Anything that is neither shuffle nor shuffle-up has to be shuffle-down.
      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
      target_name = "tir.webgpu.subgroup_shuffle_down";
    }
    return Op::Get(target_name);
  }
};

// Lowers a generic warp-shuffle call into a two-argument WebGPU subgroup call.
// T is a functor (such as WebGPUWarpIntrinsic) invoked as T()(dtype, op) to choose the
// replacement op. Only the value and the lane/delta operand of the incoming call are
// forwarded; the remaining three operands are dropped.
template <typename T>
static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  ICHECK(call != nullptr);
  // Incoming operand order: mask, value, warp_id, width, warp_size
  ICHECK_EQ(call->args.size(), 5);
  Array<PrimExpr> subgroup_args = {call->args[1], call->args[2]};
  const Op subgroup_op = T()(call->dtype, Downcast<Op>(call->op));
  return Call(call->dtype, subgroup_op, subgroup_args);
}

// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions

struct ReturnAbs {
Expand Down Expand Up @@ -113,6 +136,39 @@ TVM_REGISTER_OP("tir.trunc")
// extra dispatch
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);

// Warp-level primitives, following the approach in intrin_rule_metal.cc.
// Each of the three generic shuffle builtins below gets the same "webgpu.FLowerIntrinsic"
// rule: DispatchWebGPUShuffle instantiated with WebGPUWarpIntrinsic, which selects the
// matching tir.webgpu.subgroup_shuffle* op from the op being lowered.

// Lowering rule for tir.tvm_warp_shuffle.
TVM_REGISTER_OP("tir.tvm_warp_shuffle")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

// Lowering rule for tir.tvm_warp_shuffle_up.
TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

// Lowering rule for tir.tvm_warp_shuffle_down.
TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

// Register low-level builtin ops.
// tir.webgpu.subgroup_shuffle: two-input op (value, source lane) whose global symbol is
// "subgroupShuffle" -- presumably emitted verbatim as the WGSL builtin of that name.
// Marked kOpaque as its call effect kind.
TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle")
.set_num_inputs(2)
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("lane", "Expr", "The source thread id.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

// tir.webgpu.subgroup_shuffle_up: two-input op (value, delta) whose global symbol is the
// WGSL builtin "subgroupShuffleUp". Per the WGSL spec, that builtin returns the value held
// by the invocation whose subgroup invocation id is (current id - delta), so the argument
// description states that delta is subtracted. Marked kOpaque as its call effect kind.
TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_up")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr",
                  "The offset to be subtracted from the current lane id to get the source lane.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

// tir.webgpu.subgroup_shuffle_down: two-input op (value, delta) whose global symbol is the
// WGSL builtin "subgroupShuffleDown". Per the WGSL spec, that builtin returns the value held
// by the invocation whose subgroup invocation id is (current id + delta), so the argument
// description states that delta is added. Marked kOpaque as its call effect kind.
TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_down")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr",
                  "The offset to be added to the current lane id to get the source lane.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));


} // namespace intrin
} // namespace codegen
} // namespace tvm
2 changes: 2 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)

// Registers the "webgpu" target kind on device type kDLWebGPU with two integer attribute
// options (max_num_threads = 256, thread_warp_size = 32) and default keys {"webgpu", "gpu"}.
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<runtime::Int>("max_num_threads", runtime::Int(256))
// TODO(Charlie): Not every WebGPU implementation supports subgroups, so thread_warp_size
// should not be set unconditionally; add control logic to enable it only when supported.
// NOTE(review): 32 presumably stands in for the subgroup size -- confirm it is valid for
// the devices being targeted.
.add_attr_option<runtime::Int>("thread_warp_size", runtime::Int(32))
.set_default_keys({"webgpu", "gpu"});

TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
Expand Down
13 changes: 9 additions & 4 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
//
// The former may cause dead lock as there is a divergent
// branch with a warp sync call inside.
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
bool cast_offset_to_uint = target_->kind->name == "webgpu";
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset,
cast_offset_to_uint);
Buffer local_buf = local_bufs[i];
Stmt s = BufferStore(local_buf, other, zero_indices);
seq->push_back(s);
Expand Down Expand Up @@ -694,7 +696,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {

// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op& op, Optional<Buffer> mask_buffer, PrimExpr val,
PrimExpr delta_or_lane) {
PrimExpr delta_or_lane, bool cast_delta_to_uint = false) {
if (cast_delta_to_uint) {
delta_or_lane = cast(DataType::UInt(32, delta_or_lane.dtype().lanes()), delta_or_lane);
}
Array<PrimExpr> indices = {0};
PrimExpr mask;
if (mask_buffer.defined()) {
Expand All @@ -714,11 +719,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}

need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";

// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
Expand Down
8 changes: 7 additions & 1 deletion web/src/webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,17 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
);
}

const requiredFeatures: GPUFeatureName[] = [];
// TODO(Charlie): cannot type annotate because @webgpu/types
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@webgpu/types 0.1.55 should work now. See gpuweb/types#167

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks!

// does not have "subgroups" as GPUFeatureName yet
// const requiredFeatures: GPUFeatureName[] = [];
const requiredFeatures = [];
// Always require f16 if available
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
}
if (adapter.features.has("subgroups")) {
requiredFeatures.push("subgroups");
}

// requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
// issue when building. However, it is still needed for older browsers, hence `as any`.
Expand Down