From b698147e96070f0341094b7043a1d8eec190280f Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 16 Nov 2023 11:12:08 +0000 Subject: [PATCH 1/9] polish --- paddle/phi/api/yaml/backward.yaml | 1 + paddle/phi/api/yaml/ops.yaml | 1 + .../infermeta/spmd_rules/flash_attention.cc | 591 ++++++++++++++++++ .../infermeta/spmd_rules/flash_attention.h | 49 ++ .../semi_auto_parallel_for_flash_attention.py | 58 ++ .../test_semi_auto_parallel_basic.py | 10 + 6 files changed, 710 insertions(+) create mode 100644 paddle/phi/infermeta/spmd_rules/flash_attention.cc create mode 100644 paddle/phi/infermeta/spmd_rules/flash_attention.h create mode 100644 test/auto_parallel/semi_auto_parallel_for_flash_attention.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 89e36ef33c4ff..bb9b5e3ee83c1 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -830,6 +830,7 @@ infer_meta : func : FlashAttnGradInferMeta param : [q, k, v] + spmd_rule : FlashAttGradInferSpmd kernel : func : flash_attn_grad data_type: q diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 945ba32109ac8..968e0b64b413f 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -929,6 +929,7 @@ infer_meta : func : FlashAttnInferMeta param : [q, k, v] + spmd_rule : FlashAttInferSpmd kernel : func : flash_attn data_type : q diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc new file mode 100644 index 0000000000000..f4193f1d542d3 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc @@ -0,0 +1,591 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/spmd_rules/flash_attention.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { +using phi::distributed::auto_parallel::str_join; + +SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, + const DistMetaTensor& k, + const DistMetaTensor& v, + const DistMetaTensor& fixed_seed_offset, + const DistMetaTensor& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name) { + // q + // [batch_size, seq_len_q, num_heads, head_dim] + auto q_shape = phi::vectorize(q.dims()); + int q_ndim = q_shape.size(); + auto q_dist_attr = q.dist_attr(); + int q_dims_mapping_size = q_dist_attr.dims_mapping().size(); + + PADDLE_ENFORCE_EQ( + q_ndim, + 4, + phi::errors::InvalidArgument("The Tensor q's shape must be [batch_size, " + "seq_len_q, num_heads, head_dim]")); + + auto batch_size = q_shape[0]; + auto seq_len_q = q_shape[1]; + auto num_heads = q_shape[2]; + auto head_dim = q_shape[3]; + + PADDLE_ENFORCE_EQ( + q_ndim, + q_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + q_ndim, + q_dims_mapping_size)); + + // k + // [batch_size, seq_len_kv, num_heads, head_dim] + auto k_shape = phi::vectorize(k.dims()); + int k_ndim = k_shape.size(); + auto k_dist_attr = k.dist_attr(); + int k_dims_mapping_size = k_dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + k_ndim, + 4, + phi::errors::InvalidArgument("The Tensor k's shape must be [batch_size, " + "seq_len_kv, num_heads, head_dim]")); + + auto k_batch_size = q_shape[0]; + auto k_seq_len = q_shape[1]; + auto k_num_heads = q_shape[2]; + auto k_head_dim = q_shape[3]; + + PADDLE_ENFORCE_EQ( + batch_size, + k_batch_size, + phi::errors::InvalidArgument( + "The Tensor q and k's batch size [%d] vs [%d] are not matched.", + batch_size, + k_batch_size)); + + PADDLE_ENFORCE_EQ( + num_heads, + k_num_heads, + phi::errors::InvalidArgument( + "The Tensor q and k's k_num_heads [%d] vs [%d] are not matched.", + num_heads, + k_num_heads)); + + PADDLE_ENFORCE_EQ( + head_dim, + k_head_dim, + phi::errors::InvalidArgument( + "The Tensor q and k's head_dim [%d] vs [%d] are not matched.", + head_dim, + k_head_dim)); + + PADDLE_ENFORCE_EQ( + k_ndim, + k_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + k_ndim, + k_dims_mapping_size)); + + // v + // [batch_size, seq_len_kv, num_heads, head_dim] + auto v_shape = phi::vectorize(v.dims()); + int v_ndim = v_shape.size(); + auto v_dist_attr = v.dist_attr(); + int v_dims_mapping_size = v_dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + v_ndim, + 4, + phi::errors::InvalidArgument("The Tensor v's shape must be [batch_size, " + "seq_len_kv, num_heads, head_dim_v]")); + + auto v_batch_size = v_shape[0]; + auto v_seq_len = v_shape[1]; + auto v_num_heads = v_shape[2]; + auto v_head_dim = v_shape[3]; + + PADDLE_ENFORCE_EQ( + batch_size, + v_batch_size, + phi::errors::InvalidArgument( + "The Tensor q and v's batch size [%d] vs [%d] are not matched.", + batch_size, + v_batch_size)); + + PADDLE_ENFORCE_EQ( + num_heads, + v_num_heads, + phi::errors::InvalidArgument( + "The Tensor q and v's k_num_heads [%d] vs [%d] are not matched.", + num_heads, + v_num_heads)); + + PADDLE_ENFORCE_EQ( + k_seq_len, + v_seq_len, + phi::errors::InvalidArgument( + "The Tensor k and v's seq_len [%d] vs [%d] 
are not matched.", + k_seq_len, + v_seq_len)); + + PADDLE_ENFORCE_EQ( + v_ndim, + v_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + v_ndim, + v_dims_mapping_size)); + + // fixed_seed_offset + + // attn_mask + auto mask_shape = phi::vectorize(attn_mask.dims()); + int mask_ndim = mask_shape.size(); + auto mask_dist_attr = attn_mask.dist_attr(); + int mask_dims_mapping_size = mask_dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + mask_ndim, + mask_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + mask_ndim, + mask_dims_mapping_size)); + + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + int used_axes_index = 0; + char batch_axis = alphabet[used_axes_index++]; + char seq_len_q_axis = alphabet[used_axes_index++]; + char num_heads_axis = alphabet[used_axes_index++]; + char head_dim_axis = alphabet[used_axes_index++]; + char seq_len_kv_axis = alphabet[used_axes_index++]; + char head_dim_v_axis = alphabet[used_axes_index++]; + + // [batch_size, seq_len_q, num_heads, head_dim] + std::string q_axes = { + batch_axis, seq_len_q_axis, num_heads_axis, head_dim_axis}; + // [batch_size, seq_len_kv, num_heads, head_dim] + std::string k_axes = { + batch_axis, seq_len_kv_axis, num_heads_axis, head_dim_axis}; + // [batch_size, seq_len_kv, num_heads, head_dim_v] + std::string v_axes = { + batch_axis, seq_len_kv_axis, num_heads_axis, head_dim_v_axis}; + // [batch_size, seq_len_q, num_heads, head_dim_v] + std::string out_axes = { + batch_axis, seq_len_q_axis, num_heads_axis, head_dim_v_axis}; + // [batch_size, num_heads, seq_len_q, seq_len_kv] + std::string softmax_axes = { + batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; + // [batch_size, num_heads, seq_len_q, seq_len_kv] + std::string softmax_lse_axes = { + batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; + + std::string q_axes_align = q_axes; + q_axes_align[1] = alphabet[used_axes_index++]; + q_axes_align[3] = alphabet[used_axes_index++]; + + std::string k_axes_align = k_axes; + k_axes_align[1] = alphabet[used_axes_index++]; + k_axes_align[3] = alphabet[used_axes_index++]; + + std::string v_axes_align = v_axes; + v_axes_align[1] = alphabet[used_axes_index++]; + v_axes_align[3] = alphabet[used_axes_index++]; + + std::vector>> axes_sharding_info; + + axes_sharding_info.emplace_back(q_axes_align, q_dist_attr.dims_mapping()); + axes_sharding_info.emplace_back(k_axes_align, k_dist_attr.dims_mapping()); + axes_sharding_info.emplace_back(v_axes_align, k_dist_attr.dims_mapping()); + + auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info); + + auto q_dist_attr_dst = CopyTensorDistAttrForOutput(q_dist_attr); + auto q_dims_mapping = GetDimsMappingForAxes(q_axes, axis_to_dim_map, true); + q_dist_attr_dst.set_dims_mapping(q_dims_mapping); + auto k_dist_attr_dst = CopyTensorDistAttrForOutput(k_dist_attr); + auto k_dims_mapping = GetDimsMappingForAxes(k_axes, axis_to_dim_map, true); + k_dist_attr_dst.set_dims_mapping(k_dims_mapping); + auto v_dist_attr_dst = CopyTensorDistAttrForOutput(v_dist_attr); + auto v_dims_mapping = GetDimsMappingForAxes(v_axes, axis_to_dim_map, true); + v_dist_attr_dst.set_dims_mapping(v_dims_mapping); + + // TODO(liuzhenhai): process fixed_seed_offset and attn_mask + auto fixed_seed_offset_dist_attr = fixed_seed_offset.dist_attr(); + auto attn_mask_dist_attr = attn_mask.dist_attr(); + + TensorDistAttr out; + auto out_dims_mapping = 
+ GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); + out.set_dims_mapping(out_dims_mapping); + TensorDistAttr softmax; + softmax.set_dims_mapping( + GetDimsMappingForAxes(softmax_axes, axis_to_dim_map, true)); + TensorDistAttr softmax_lse; + softmax_lse.set_dims_mapping( + GetDimsMappingForAxes(softmax_lse_axes, axis_to_dim_map, true)); + TensorDistAttr seed_offset = + CopyTensorDistAttrForOutput(fixed_seed_offset_dist_attr); + // same as input + seed_offset.set_dims_mapping(fixed_seed_offset_dist_attr.dims_mapping()); + + VLOG(4) << "FlashAttInferSpmd:"; + VLOG(4) << "Einsum Notation: " << q_axes << "," << k_axes << "," << v_axes + << "-->" << out_axes << "," << softmax_axes << "," + << softmax_lse_axes; + + VLOG(4) << "q"; + VLOG(4) << "Input shape: [" << str_join(q_shape) << "] " + << "src_dims_mapping: [" << str_join(q_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(q_dims_mapping) << "]"; + + VLOG(4) << "k"; + VLOG(4) << "Input shape: [" << str_join(k_shape) << "] " + << "src_dims_mapping: [" << str_join(k_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; + + VLOG(4) << "v"; + VLOG(4) << "Input shape: [" << str_join(v_shape) << "] " + << "src_dims_mapping: [" << str_join(v_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; + + VLOG(4) << "Output" + << " dims_mapping: [" << str_join(out_dims_mapping) << "]"; + VLOG(4) << std::endl; + + return {{q_dist_attr_dst, + k_dist_attr_dst, + v_dist_attr_dst, + fixed_seed_offset_dist_attr, + attn_mask_dist_attr}, + {out, softmax, softmax_lse, seed_offset}}; +} + +SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, + const DistMetaTensor& k, + const DistMetaTensor& v, + const DistMetaTensor& out, + const DistMetaTensor& softmax_lse, + const DistMetaTensor& seed_offset, + const DistMetaTensor& attn_mask, + const DistMetaTensor& out_grad, + float dropout, + bool causal) { + // q + // [batch_size, seq_len_q, num_heads, head_dim] + auto q_shape = phi::vectorize(q.dims()); + int q_ndim = q_shape.size(); + auto q_dist_attr = q.dist_attr(); + int q_dims_mapping_size = q_dist_attr.dims_mapping().size(); + + PADDLE_ENFORCE_EQ( + q_ndim, + 4, + phi::errors::InvalidArgument("The Tensor q's shape must be [batch_size, " + "seq_len_q, num_heads, head_dim]")); + + auto batch_size = q_shape[0]; + auto seq_len_q = q_shape[1]; + auto num_heads = q_shape[2]; + auto head_dim = q_shape[3]; + + PADDLE_ENFORCE_EQ( + q_ndim, + q_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + q_ndim, + q_dims_mapping_size)); + + // k + // [batch_size, seq_len_kv, num_heads, head_dim] + auto k_shape = phi::vectorize(k.dims()); + int k_ndim = k_shape.size(); + auto k_dist_attr = k.dist_attr(); + int k_dims_mapping_size = k_dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + k_ndim, + 4, + phi::errors::InvalidArgument("The Tensor k's shape must be [batch_size, " + "seq_len_kv, num_heads, head_dim]")); + + auto k_batch_size = q_shape[0]; + auto k_seq_len = q_shape[1]; + auto k_num_heads = q_shape[2]; + auto k_head_dim = q_shape[3]; + + PADDLE_ENFORCE_EQ( + batch_size, + k_batch_size, + phi::errors::InvalidArgument( + "The Tensor q and k's batch size [%d] vs [%d] are not matched.", + batch_size, + k_batch_size)); + + PADDLE_ENFORCE_EQ( + num_heads, + k_num_heads, + phi::errors::InvalidArgument( + "The Tensor q and k's k_num_heads [%d] vs [%d] are not matched.", + 
num_heads, + k_num_heads)); + + PADDLE_ENFORCE_EQ( + head_dim, + k_head_dim, + phi::errors::InvalidArgument( + "The Tensor q and k's head_dim [%d] vs [%d] are not matched.", + head_dim, + k_head_dim)); + + PADDLE_ENFORCE_EQ( + k_ndim, + k_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + k_ndim, + k_dims_mapping_size)); + + // v + // [batch_size, seq_len_kv, num_heads, head_dim] + auto v_shape = phi::vectorize(v.dims()); + int v_ndim = v_shape.size(); + auto v_dist_attr = v.dist_attr(); + int v_dims_mapping_size = v_dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + v_ndim, + 4, + phi::errors::InvalidArgument("The Tensor v's shape must be [batch_size, " + "seq_len_kv, num_heads, head_dim_v]")); + + auto v_batch_size = v_shape[0]; + auto v_seq_len = v_shape[1]; + auto v_num_heads = v_shape[2]; + auto v_head_dim = v_shape[3]; + + PADDLE_ENFORCE_EQ( + batch_size, + v_batch_size, + phi::errors::InvalidArgument( + "The Tensor q and v's batch size [%d] vs [%d] are not matched.", + batch_size, + v_batch_size)); + + PADDLE_ENFORCE_EQ( + num_heads, + v_num_heads, + phi::errors::InvalidArgument( + "The Tensor q and v's k_num_heads [%d] vs [%d] are not matched.", + num_heads, + v_num_heads)); + + PADDLE_ENFORCE_EQ( + k_seq_len, + v_seq_len, + phi::errors::InvalidArgument( + "The Tensor k and v's seq_len [%d] vs [%d] are not matched.", + k_seq_len, + v_seq_len)); + + PADDLE_ENFORCE_EQ( + v_ndim, + v_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + v_ndim, + v_dims_mapping_size)); + + // fixed_seed_offset + + // attn_mask + auto mask_shape = phi::vectorize(attn_mask.dims()); + int mask_ndim = mask_shape.size(); + auto mask_dist_attr = attn_mask.dist_attr(); + int mask_dims_mapping_size = mask_dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + mask_ndim, + mask_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + mask_ndim, + mask_dims_mapping_size)); + + auto out_shape = phi::vectorize(out.dims()); + auto out_dist_attr = attn_mask.dist_attr(); + + auto out_grad_shape = phi::vectorize(out_grad.dims()); + auto out_grad_dist_attr = out_grad.dist_attr(); + + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + int used_axes_index = 0; + char batch_axis = alphabet[used_axes_index++]; + char seq_len_q_axis = alphabet[used_axes_index++]; + char num_heads_axis = alphabet[used_axes_index++]; + char head_dim_axis = alphabet[used_axes_index++]; + char seq_len_kv_axis = alphabet[used_axes_index++]; + char head_dim_v_axis = alphabet[used_axes_index++]; + + // [batch_size, seq_len_q, num_heads, head_dim] + std::string q_axes = { + batch_axis, seq_len_q_axis, num_heads_axis, head_dim_axis}; + // [batch_size, seq_len_kv, num_heads, head_dim] + std::string k_axes = { + batch_axis, seq_len_kv_axis, num_heads_axis, head_dim_axis}; + // [batch_size, seq_len_kv, num_heads, head_dim_v] + std::string v_axes = { + batch_axis, seq_len_kv_axis, num_heads_axis, head_dim_v_axis}; + // [batch_size, seq_len_q, num_heads, head_dim_v] + std::string out_axes = { + batch_axis, seq_len_q_axis, num_heads_axis, head_dim_v_axis}; + // [batch_size, num_heads, seq_len_q, seq_len_kv] + std::string softmax_axes = { + batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; + // [batch_size, num_heads, seq_len_q, seq_len_kv] + std::string softmax_lse_axes = { + batch_axis, 
num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; + + std::string q_axes_align = q_axes; + q_axes_align[1] = alphabet[used_axes_index++]; + q_axes_align[3] = alphabet[used_axes_index++]; + + std::string k_axes_align = k_axes; + k_axes_align[1] = alphabet[used_axes_index++]; + k_axes_align[3] = alphabet[used_axes_index++]; + + std::string v_axes_align = v_axes; + v_axes_align[1] = alphabet[used_axes_index++]; + v_axes_align[3] = alphabet[used_axes_index++]; + + std::string out_axes_align = out_axes; + out_axes_align[1] = alphabet[used_axes_index++]; + out_axes_align[3] = alphabet[used_axes_index++]; + + std::string out_grad_axes_align = out_axes; + out_grad_axes_align[1] = alphabet[used_axes_index++]; + out_grad_axes_align[3] = alphabet[used_axes_index++]; + + std::vector>> axes_sharding_info; + + axes_sharding_info.emplace_back(q_axes_align, q_dist_attr.dims_mapping()); + axes_sharding_info.emplace_back(k_axes_align, k_dist_attr.dims_mapping()); + axes_sharding_info.emplace_back(v_axes_align, k_dist_attr.dims_mapping()); + axes_sharding_info.emplace_back(out_axes_align, out_dist_attr.dims_mapping()); + axes_sharding_info.emplace_back(out_grad_axes_align, + out_grad_dist_attr.dims_mapping()); + + auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info); + + auto q_dist_attr_dst = CopyTensorDistAttrForOutput(q_dist_attr); + auto q_dims_mapping = GetDimsMappingForAxes(q_axes, axis_to_dim_map, true); + q_dist_attr_dst.set_dims_mapping(q_dims_mapping); + auto k_dist_attr_dst = CopyTensorDistAttrForOutput(k_dist_attr); + auto k_dims_mapping = GetDimsMappingForAxes(k_axes, axis_to_dim_map, true); + k_dist_attr_dst.set_dims_mapping(k_dims_mapping); + auto v_dist_attr_dst = CopyTensorDistAttrForOutput(v_dist_attr); + auto v_dims_mapping = GetDimsMappingForAxes(v_axes, axis_to_dim_map, true); + v_dist_attr_dst.set_dims_mapping(v_dims_mapping); + auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr); + auto out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); + out_dist_attr_dst.set_dims_mapping(out_dims_mapping); + + // TODO(liuzhenhai): process fixed_seed_offset and attn_mask + auto fixed_seed_offset_dist_attr = seed_offset.dist_attr(); + auto attn_mask_dist_attr = attn_mask.dist_attr(); + + auto out_grad_dist_attr_dst = CopyTensorDistAttrForOutput(out_grad_dist_attr); + auto out_grad_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); + v_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping); + + TensorDistAttr q_grad; + auto q_grad_dims_mapping = + GetDimsMappingForAxes(q_axes, axis_to_dim_map, true); + q_grad.set_dims_mapping(q_grad_dims_mapping); + + TensorDistAttr k_grad; + auto k_grad_dims_mapping = + GetDimsMappingForAxes(k_axes, axis_to_dim_map, true); + k_grad.set_dims_mapping(k_grad_dims_mapping); + + TensorDistAttr v_grad; + auto v_grad_dims_mapping = + GetDimsMappingForAxes(v_axes, axis_to_dim_map, true); + v_grad.set_dims_mapping(v_grad_dims_mapping); + + VLOG(4) << "FlashAttInferSpmd:"; + VLOG(4) << "Einsum Notation: " << q_axes << "," << k_axes << "," << v_axes + << "-->" << out_axes << "," << softmax_axes << "," + << softmax_lse_axes; + + VLOG(4) << "q"; + VLOG(4) << "Input shape: [" << str_join(q_shape) << "] " + << "src_dims_mapping: [" << str_join(q_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(q_dims_mapping) << "]"; + + VLOG(4) << "k"; + VLOG(4) << "Input shape: [" << str_join(k_shape) << "] " + << "src_dims_mapping: [" << str_join(k_dist_attr.dims_mapping()) + << "] " + << 
"dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; + + VLOG(4) << "v"; + VLOG(4) << "Input shape: [" << str_join(v_shape) << "] " + << "src_dims_mapping: [" << str_join(v_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; + + VLOG(4) << "out"; + VLOG(4) << "Input shape: [" << str_join(out_shape) << "] " + << "src_dims_mapping: [" << str_join(out_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(out_dims_mapping) << "]"; + + VLOG(4) << "out_grad"; + VLOG(4) << "Input shape: [" << str_join(out_grad_shape) << "] " + << "src_dims_mapping: [" + << str_join(out_grad_dist_attr.dims_mapping()) << "] " + << "dst_dims_mapping: [" << str_join(out_grad_dims_mapping) << "]"; + + VLOG(4) << "q_grad" + << " dims_mapping: [" << str_join(q_grad_dims_mapping) << "]"; + VLOG(4) << "k_grad" + << " dims_mapping: [" << str_join(k_grad_dims_mapping) << "]"; + VLOG(4) << "v_grad" + << " dims_mapping: [" << str_join(v_grad_dims_mapping) << "]"; + + VLOG(4) << std::endl; + + return {{q_dist_attr_dst, + k_dist_attr_dst, + v_dist_attr_dst, + out_dist_attr_dst, + fixed_seed_offset_dist_attr, + attn_mask_dist_attr, + out_grad_dist_attr_dst}, + {q_grad, k_grad, v_grad}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.h b/paddle/phi/infermeta/spmd_rules/flash_attention.h new file mode 100644 index 0000000000000..5cb881dd9ad50 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/flash_attention.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, + const DistMetaTensor& k, + const DistMetaTensor& v, + const DistMetaTensor& fixed_seed_offset, + const DistMetaTensor& attn_mask, + float dropout = 0.0, + bool causal = false, + bool return_softmax = false, + bool is_test = false, + const std::string& rng_name = ""); + +SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, + const DistMetaTensor& k, + const DistMetaTensor& v, + const DistMetaTensor& out, + const DistMetaTensor& softmax_lse, + const DistMetaTensor& seed_offset, + const DistMetaTensor& attn_mask, + const DistMetaTensor& out_grad, + float dropout = 0.0, + bool causal = false); + +} // namespace distributed +} // namespace phi diff --git a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py new file mode 100644 index 0000000000000..6dcef7646dde5 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +class TestFlashAttentionSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def check_dim_mapping(self, output, expected_dim_mapping): + assert ( + output.dist_attr.dims_mapping == expected_dim_mapping + ), f"{output.dist_attr.dims_mapping} vs {expected_dim_mapping}" + + def test_concat_forward_reshard(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [['x', None, None], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + self.check_dim_mapping(outputs, [-1, -1, 0]) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_concat_forward() + + # all to all is not supported yet for cpu + if self._backend == "gpu": + self.test_concat_forward_reshard() + + +if __name__ == '__main__': + TestFlashAttentionSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 2589566cb670e..c81704abb35a4 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -106,6 +106,16 @@ def test_custom_relu_api(self): user_defined_envs=envs, ) + def test_flash_attention_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_flash_attention.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() From 9c5c8b24497b5dd3f8c29ba77311f688a6877ac1 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 16 Nov 2023 11:38:36 +0000 Subject: [PATCH 2/9] polish --- .../semi_auto_parallel_for_flash_attention.py | 30 +++++++++++++------ .../test_semi_auto_parallel_basic.py | 2 ++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py index 6dcef7646dde5..60ce251a735b0 100644 --- a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py +++ b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py @@ -16,6 +16,7 @@ import paddle import paddle.distributed as dist +from paddle.nn.functional.flash_attention import flash_attention class TestFlashAttentionSemiAutoParallel(SemiAutoParallelTestBase): @@ -27,17 +28,28 @@ def check_dim_mapping(self, output, expected_dim_mapping): output.dist_attr.dims_mapping == expected_dim_mapping ), f"{output.dist_attr.dims_mapping} vs {expected_dim_mapping}" - def test_concat_forward_reshard(self): - shapes = [[16, 4, 4], [64, 4, 4]] - specs = [['x', 
None, None], [None, None, 'x']] + def test_flash_att_forward(self): + shapes = [[2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]] + specs = [['x', None, None], ["x", None, None], ['x', None, None]] inputs, outputs = self.runfunc_and_check( inputs_shape=shapes, inputs_specs=specs, - op_func=paddle.concat, - with_backward=False, - axis=0, + op_func=flash_attention, + with_backward=True, + causal=True, ) - self.check_dim_mapping(outputs, [-1, -1, 0]) + + def test_flash_att_forward_reshard(self): + shapes = [[2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]] + specs = [['x', None, None], [None, None, 'x'], ['x', None, None]] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=flash_attention, + with_backward=True, + causal=True, + ) + # self.check_dim_mapping(outputs, [-1, -1, 0]) def run_test_case(self): if self._backend == "cpu": @@ -47,11 +59,11 @@ def run_test_case(self): else: raise ValueError("Only support cpu or gpu backend.") - self.test_concat_forward() + self.test_flash_att_forward() # all to all is not supported yet for cpu if self._backend == "gpu": - self.test_concat_forward_reshard() + self.test_flash_att_forward_reshard() if __name__ == '__main__': diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index c81704abb35a4..cfd2029f03452 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -26,6 +26,7 @@ def setUp(self): self._default_envs = {"dtype": "float32", "seed": "2023"} self._changeable_envs = {"backend": ["cpu", "gpu"]} + """ def test_matmul_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -105,6 +106,7 @@ def test_custom_relu_api(self): "semi_auto_parallel_for_custom_relu.py", user_defined_envs=envs, ) + """ def test_flash_attention_api(self): envs_list = test_base.gen_product_envs_list( From a1d6436b18fb7244f6479137776e3a7babcf2afe Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Fri, 17 Nov 2023 10:43:46 +0800 Subject: [PATCH 3/9] polish --- paddle/phi/api/yaml/generator/dist_api_gen.py | 13 ++++- .../infermeta/spmd_rules/flash_attention.cc | 4 -- paddle/phi/infermeta/spmd_rules/rules.h | 1 + .../semi_auto_parallel_for_flash_attention.py | 52 ++++++++++++++++--- test/auto_parallel/semi_auto_parallel_util.py | 11 ++-- .../test_semi_auto_parallel_basic.py | 2 +- 6 files changed, 66 insertions(+), 17 deletions(-) diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 431a98d829ae6..83849d1990556 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -705,6 +705,14 @@ def generate_specialized_infer_spmd_code(self) -> str: name=param ) input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const paddle::optional&" + ): + input_decl_code += ( + OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE.format(name=param) + ) + input_args_code += "meta_dist_input_" + param + ", " elif ( self.inputs['input_info'][param] == "const std::vector&" @@ -826,7 +834,10 @@ def generate_output_creation_code(self) -> str: # kernel output generate self.dist_output_args.append('dist_out') self.dense_output_args.append('dense_out') - if self.outputs['types'][0] == 'Tensor': + if ( + self.outputs['types'][0] == 'Tensor' + or self.outputs['types'][0] == 'const paddle::optional' + ): if 
self.infer_meta['spmd_rule'] is not None: output_creation_code += SINGLE_OUT_CREATION_TEMPLATE else: diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc index f4193f1d542d3..6b12f3fb09173 100644 --- a/paddle/phi/infermeta/spmd_rules/flash_attention.cc +++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc @@ -44,7 +44,6 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, "seq_len_q, num_heads, head_dim]")); auto batch_size = q_shape[0]; - auto seq_len_q = q_shape[1]; auto num_heads = q_shape[2]; auto head_dim = q_shape[3]; @@ -120,7 +119,6 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, auto v_batch_size = v_shape[0]; auto v_seq_len = v_shape[1]; auto v_num_heads = v_shape[2]; - auto v_head_dim = v_shape[3]; PADDLE_ENFORCE_EQ( batch_size, @@ -305,7 +303,6 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, "seq_len_q, num_heads, head_dim]")); auto batch_size = q_shape[0]; - auto seq_len_q = q_shape[1]; auto num_heads = q_shape[2]; auto head_dim = q_shape[3]; @@ -381,7 +378,6 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, auto v_batch_size = v_shape[0]; auto v_seq_len = v_shape[1]; auto v_num_heads = v_shape[2]; - auto v_head_dim = v_shape[3]; PADDLE_ENFORCE_EQ( batch_size, diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index f685b4ece261f..08b8fc37c366d 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" #include "paddle/phi/infermeta/spmd_rules/embedding.h" +#include "paddle/phi/infermeta/spmd_rules/flash_attention.h" #include "paddle/phi/infermeta/spmd_rules/flatten.h" #include "paddle/phi/infermeta/spmd_rules/layer_norm.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" diff --git a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py index 60ce251a735b0..00d8561f1e1b3 100644 --- a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py +++ b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py @@ -29,19 +29,52 @@ def check_dim_mapping(self, output, expected_dim_mapping): ), f"{output.dist_attr.dims_mapping} vs {expected_dim_mapping}" def test_flash_att_forward(self): - shapes = [[2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]] - specs = [['x', None, None], ["x", None, None], ['x', None, None]] + shapes = ([2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]) + specs = ( + ['x', None, None, None], + ["x", None, None, None], + ['x', None, None, None], + ) + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=flash_attention, + with_backward=True, + causal=True, + ) + self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) + self.check_dim_mapping(inputs[0].grad, [0, -1, -1, -1]) + self.check_dim_mapping(inputs[1].grad, [0, -1, -1, -1]) + self.check_dim_mapping(inputs[2].grad, [0, -1, -1, -1]) + + def test_flash_att_forward_return_softmax(self): + shapes = ([2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]) + specs = ( + ['x', None, None, None], + ["x", None, None, None], + ['x', None, None, None], + ) inputs, outputs = self.runfunc_and_check( inputs_shape=shapes, inputs_specs=specs, op_func=flash_attention, with_backward=True, causal=True, + return_softmax=True, ) + 
self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) + self.check_dim_mapping(outputs[1], [0, -1, -1, -1]) + self.check_dim_mapping(inputs[0].grad, [0, -1, -1, -1]) + self.check_dim_mapping(inputs[1].grad, [0, -1, -1, -1]) + self.check_dim_mapping(inputs[2].grad, [0, -1, -1, -1]) def test_flash_att_forward_reshard(self): - shapes = [[2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]] - specs = [['x', None, None], [None, None, 'x'], ['x', None, None]] + shapes = ([2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]) + specs = ( + ['x', None, None, None], + [None, None, None, 'x'], + ['x', None, None, None], + ) inputs, outputs = self.runfunc_and_check( inputs_shape=shapes, inputs_specs=specs, @@ -49,7 +82,10 @@ def test_flash_att_forward_reshard(self): with_backward=True, causal=True, ) - # self.check_dim_mapping(outputs, [-1, -1, 0]) + self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) + self.check_dim_mapping(inputs[0].grad, [0, -1, -1, -1]) + self.check_dim_mapping(inputs[1].grad, [0, -1, -1, -1]) + self.check_dim_mapping(inputs[2].grad, [0, -1, -1, -1]) def run_test_case(self): if self._backend == "cpu": @@ -59,10 +95,10 @@ def run_test_case(self): else: raise ValueError("Only support cpu or gpu backend.") - self.test_flash_att_forward() - - # all to all is not supported yet for cpu + # flash attention is not supported yet for cpu if self._backend == "gpu": + self.test_flash_att_forward() + self.test_flash_att_forward_return_softmax() self.test_flash_att_forward_reshard() diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py index cfb905e8382a2..84567d19b30ac 100644 --- a/test/auto_parallel/semi_auto_parallel_util.py +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -28,6 +28,9 @@ def __init__(self): self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) def check_tensor_eq(self, a, b): + if a is None: + assert b is None + return np1 = a.numpy() np2 = b.numpy() np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) @@ -92,8 +95,9 @@ def terminal_cond(x): assert len(flat_inputs_specs) == len(flat_inputs_shape) for shape, spec in zip(flat_inputs_shape, flat_inputs_specs): - input_np = np.random.random(size=shape).astype(self._dtype) + input_np = np.random.random(size=shape).astype("float32") input = paddle.to_tensor(input_np) + input = paddle.cast(input, self._dtype).detach() input.stop_gradient = False input_dist_attr = dist.DistAttr( mesh=self._mesh, sharding_specs=spec @@ -124,8 +128,9 @@ def terminal_cond2(x): assert len(flat_out) == len(flat_dist_out) for output, dist_output in zip(flat_out, flat_dist_out): self.check_tensor_eq(out, dist_out) - output.backward() - dist_output.backward() + if out is not None: + output.backward() + dist_output.backward() for x, dist_x in zip(flat_inputs, flat_dist_inputs): self.check_tensor_eq(x.grad, dist_x.grad) diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index cfd2029f03452..c9b637249ab4d 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -110,7 +110,7 @@ def test_custom_relu_api(self): def test_flash_attention_api(self): envs_list = test_base.gen_product_envs_list( - self._default_envs, self._changeable_envs + {"dtype": "float16", "seed": "2023"}, self._changeable_envs ) for envs in envs_list: self.run_test_case( From fcddf1d8d7510c3774f985de1305df79f23aa3c1 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 21 Nov 2023 17:31:53 
+0800 Subject: [PATCH 4/9] polish --- .../infermeta/spmd_rules/flash_attention.cc | 16 ++++++---- .../semi_auto_parallel_for_flash_attention.py | 32 ++----------------- 2 files changed, 11 insertions(+), 37 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc index 6b12f3fb09173..305b47a888ed4 100644 --- a/paddle/phi/infermeta/spmd_rules/flash_attention.cc +++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc @@ -159,13 +159,15 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, int mask_ndim = mask_shape.size(); auto mask_dist_attr = attn_mask.dist_attr(); int mask_dims_mapping_size = mask_dist_attr.dims_mapping().size(); - PADDLE_ENFORCE_EQ( - mask_ndim, - mask_dims_mapping_size, - phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " - "dims_mapping size [%d] are not matched.", - mask_ndim, - mask_dims_mapping_size)); + if (!IsEmpty(mask_shape)) { + PADDLE_ENFORCE_EQ( + mask_ndim, + mask_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor mask's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + mask_ndim, + mask_dims_mapping_size)); + } std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; int used_axes_index = 0; diff --git a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py index 00d8561f1e1b3..9975b976248a4 100644 --- a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py +++ b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py @@ -39,34 +39,10 @@ def test_flash_att_forward(self): inputs_shape=shapes, inputs_specs=specs, op_func=flash_attention, - with_backward=True, + with_backward=False, causal=True, ) self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) - self.check_dim_mapping(inputs[0].grad, [0, -1, -1, -1]) - self.check_dim_mapping(inputs[1].grad, [0, -1, -1, -1]) - self.check_dim_mapping(inputs[2].grad, [0, -1, -1, -1]) - - def test_flash_att_forward_return_softmax(self): - shapes = ([2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]) - specs = ( - ['x', None, None, None], - ["x", None, None, None], - ['x', None, None, None], - ) - inputs, outputs = self.runfunc_and_check( - inputs_shape=shapes, - inputs_specs=specs, - op_func=flash_attention, - with_backward=True, - causal=True, - return_softmax=True, - ) - self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) - self.check_dim_mapping(outputs[1], [0, -1, -1, -1]) - self.check_dim_mapping(inputs[0].grad, [0, -1, -1, -1]) - self.check_dim_mapping(inputs[1].grad, [0, -1, -1, -1]) - self.check_dim_mapping(inputs[2].grad, [0, -1, -1, -1]) def test_flash_att_forward_reshard(self): shapes = ([2, 256, 2, 128], [2, 256, 2, 128], [2, 256, 2, 128]) @@ -79,13 +55,10 @@ def test_flash_att_forward_reshard(self): inputs_shape=shapes, inputs_specs=specs, op_func=flash_attention, - with_backward=True, + with_backward=False, causal=True, ) self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) - self.check_dim_mapping(inputs[0].grad, [0, -1, -1, -1]) - self.check_dim_mapping(inputs[1].grad, [0, -1, -1, -1]) - self.check_dim_mapping(inputs[2].grad, [0, -1, -1, -1]) def run_test_case(self): if self._backend == "cpu": @@ -98,7 +71,6 @@ def run_test_case(self): # flash attention is not supported yet for cpu if self._backend == "gpu": self.test_flash_att_forward() - self.test_flash_att_forward_return_softmax() self.test_flash_att_forward_reshard() From b4eb1450556d78c48085b97e351286428249f4f9 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: 
Tue, 21 Nov 2023 17:34:40 +0800 Subject: [PATCH 5/9] polish --- test/auto_parallel/test_semi_auto_parallel_basic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 25f49c3329ad3..2f1d0297584cd 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -26,7 +26,6 @@ def setUp(self): self._default_envs = {"dtype": "float32", "seed": "2023"} self._changeable_envs = {"backend": ["cpu", "gpu"]} - """ def test_matmul_api(self): default_envs = self._default_envs default_envs["NVIDIA_TF32_OVERRIDE"] = "0" @@ -118,7 +117,6 @@ def test_custom_relu_api(self): "semi_auto_parallel_for_custom_relu.py", user_defined_envs=envs, ) - """ def test_flash_attention_api(self): envs_list = test_base.gen_product_envs_list( From a285e0a5e4e1e512d6dce8fc560464c9fd3c52ce Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 22 Nov 2023 19:05:57 +0800 Subject: [PATCH 6/9] polish --- .../infermeta/spmd_rules/flash_attention.cc | 265 ++++++++---------- .../semi_auto_parallel_for_flash_attention.py | 4 +- test/auto_parallel/semi_auto_parallel_util.py | 8 +- 3 files changed, 118 insertions(+), 159 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc index 305b47a888ed4..02ef6e49df1a9 100644 --- a/paddle/phi/infermeta/spmd_rules/flash_attention.cc +++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc @@ -18,8 +18,33 @@ limitations under the License. */ namespace phi { namespace distributed { + +#define LOG_SPMD_INPUT(name) \ + do { \ + VLOG(4) << #name; \ + VLOG(4) << "shape: [" << str_join(name##_shape) << "] " \ + << "src_dist_attr: [" << name##_dist_attr.to_string() << "] " \ + << "src_dist_attr: [" << name##_dist_attr_dst.to_string() << "]"; \ + } while (0) + +#define LOG_SPMD_OUTPUT(name) \ + do { \ + VLOG(4) << #name; \ + VLOG(4) << "src_dist_attr: [" << name.to_string() << "]"; \ + } while (0) + using phi::distributed::auto_parallel::str_join; +TensorDistAttr MapDims( + const TensorDistAttr& src, + const std::unordered_map& axes_mapping, + const std::string& axes) { + auto dst = CopyTensorDistAttrForOutput(src); + auto dims_mapping = GetDimsMappingForAxes(axes, axes_mapping, true); + dst.set_dims_mapping(dims_mapping); + return dst; +} + SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, const DistMetaTensor& k, const DistMetaTensor& v, @@ -153,13 +178,15 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, v_dims_mapping_size)); // fixed_seed_offset - + // TODO(liuzhenhai): process fixed_seed_offset and attn_mask + auto fixed_seed_offset_dist_attr = fixed_seed_offset.dist_attr(); + auto fixed_seed_offset_shape = phi::vectorize(fixed_seed_offset.dims()); // attn_mask - auto mask_shape = phi::vectorize(attn_mask.dims()); - int mask_ndim = mask_shape.size(); - auto mask_dist_attr = attn_mask.dist_attr(); - int mask_dims_mapping_size = mask_dist_attr.dims_mapping().size(); - if (!IsEmpty(mask_shape)) { + auto attn_mask_shape = phi::vectorize(attn_mask.dims()); + int mask_ndim = attn_mask_shape.size(); + auto attn_mask_dist_attr = attn_mask.dist_attr(); + int mask_dims_mapping_size = attn_mask_dist_attr.dims_mapping().size(); + if (!IsEmpty(attn_mask_shape)) { PADDLE_ENFORCE_EQ( mask_ndim, mask_dims_mapping_size, @@ -194,8 +221,7 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, std::string softmax_axes = { batch_axis, num_heads_axis, seq_len_q_axis, 
seq_len_kv_axis}; // [batch_size, num_heads, seq_len_q, seq_len_kv] - std::string softmax_lse_axes = { - batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; + std::string softmax_lse_axes = {batch_axis, num_heads_axis, seq_len_q_axis}; std::string q_axes_align = q_axes; q_axes_align[1] = alphabet[used_axes_index++]; @@ -217,67 +243,42 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info); - auto q_dist_attr_dst = CopyTensorDistAttrForOutput(q_dist_attr); - auto q_dims_mapping = GetDimsMappingForAxes(q_axes, axis_to_dim_map, true); - q_dist_attr_dst.set_dims_mapping(q_dims_mapping); - auto k_dist_attr_dst = CopyTensorDistAttrForOutput(k_dist_attr); - auto k_dims_mapping = GetDimsMappingForAxes(k_axes, axis_to_dim_map, true); - k_dist_attr_dst.set_dims_mapping(k_dims_mapping); - auto v_dist_attr_dst = CopyTensorDistAttrForOutput(v_dist_attr); - auto v_dims_mapping = GetDimsMappingForAxes(v_axes, axis_to_dim_map, true); - v_dist_attr_dst.set_dims_mapping(v_dims_mapping); + auto q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes); + auto k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes); + auto v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes); - // TODO(liuzhenhai): process fixed_seed_offset and attn_mask - auto fixed_seed_offset_dist_attr = fixed_seed_offset.dist_attr(); - auto attn_mask_dist_attr = attn_mask.dist_attr(); + // TODO(liuzhenhai): process fixed_seed and attn_mask + auto fixed_seed_offset_dist_attr_dst = fixed_seed_offset_dist_attr; + auto attn_mask_dist_attr_dst = attn_mask_dist_attr; + + auto out = MapDims(q_dist_attr, axis_to_dim_map, out_axes); + auto softmax = MapDims(q_dist_attr, axis_to_dim_map, softmax_axes); + auto softmax_lse = MapDims(q_dist_attr, axis_to_dim_map, softmax_lse_axes); - TensorDistAttr out; - auto out_dims_mapping = - GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); - out.set_dims_mapping(out_dims_mapping); - TensorDistAttr softmax; - softmax.set_dims_mapping( - GetDimsMappingForAxes(softmax_axes, axis_to_dim_map, true)); - TensorDistAttr softmax_lse; - softmax_lse.set_dims_mapping( - GetDimsMappingForAxes(softmax_lse_axes, axis_to_dim_map, true)); - TensorDistAttr seed_offset = - CopyTensorDistAttrForOutput(fixed_seed_offset_dist_attr); - // same as input - seed_offset.set_dims_mapping(fixed_seed_offset_dist_attr.dims_mapping()); + TensorDistAttr seed_offset = fixed_seed_offset_dist_attr; VLOG(4) << "FlashAttInferSpmd:"; VLOG(4) << "Einsum Notation: " << q_axes << "," << k_axes << "," << v_axes << "-->" << out_axes << "," << softmax_axes << "," << softmax_lse_axes; - VLOG(4) << "q"; - VLOG(4) << "Input shape: [" << str_join(q_shape) << "] " - << "src_dims_mapping: [" << str_join(q_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(q_dims_mapping) << "]"; - - VLOG(4) << "k"; - VLOG(4) << "Input shape: [" << str_join(k_shape) << "] " - << "src_dims_mapping: [" << str_join(k_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; - - VLOG(4) << "v"; - VLOG(4) << "Input shape: [" << str_join(v_shape) << "] " - << "src_dims_mapping: [" << str_join(v_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; - - VLOG(4) << "Output" - << " dims_mapping: [" << str_join(out_dims_mapping) << "]"; + LOG_SPMD_INPUT(q); + LOG_SPMD_INPUT(k); + LOG_SPMD_INPUT(v); + LOG_SPMD_INPUT(fixed_seed_offset); + LOG_SPMD_INPUT(attn_mask); + VLOG(4) << 
"Outputs:"; + LOG_SPMD_OUTPUT(out); + LOG_SPMD_OUTPUT(softmax); + LOG_SPMD_OUTPUT(softmax_lse); + LOG_SPMD_OUTPUT(seed_offset); VLOG(4) << std::endl; return {{q_dist_attr_dst, k_dist_attr_dst, v_dist_attr_dst, - fixed_seed_offset_dist_attr, - attn_mask_dist_attr}, + fixed_seed_offset_dist_attr_dst, + attn_mask_dist_attr_dst}, {out, softmax, softmax_lse, seed_offset}}; } @@ -360,7 +361,7 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, PADDLE_ENFORCE_EQ( k_ndim, k_dims_mapping_size, - phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + phi::errors::InvalidArgument("The Tensor k's rank [%d] and Its " "dims_mapping size [%d] are not matched.", k_ndim, k_dims_mapping_size)); @@ -408,28 +409,35 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, PADDLE_ENFORCE_EQ( v_ndim, v_dims_mapping_size, - phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " + phi::errors::InvalidArgument("The Tensor v's rank [%d] and Its " "dims_mapping size [%d] are not matched.", v_ndim, v_dims_mapping_size)); // fixed_seed_offset + auto seed_offset_dist_attr = seed_offset.dist_attr(); + auto seed_offset_shape = phi::vectorize(seed_offset.dims()); // attn_mask - auto mask_shape = phi::vectorize(attn_mask.dims()); - int mask_ndim = mask_shape.size(); - auto mask_dist_attr = attn_mask.dist_attr(); - int mask_dims_mapping_size = mask_dist_attr.dims_mapping().size(); - PADDLE_ENFORCE_EQ( - mask_ndim, - mask_dims_mapping_size, - phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its " - "dims_mapping size [%d] are not matched.", - mask_ndim, - mask_dims_mapping_size)); + auto attn_mask_shape = phi::vectorize(attn_mask.dims()); + int mask_ndim = attn_mask_shape.size(); + auto attn_mask_dist_attr = attn_mask.dist_attr(); + int mask_dims_mapping_size = attn_mask_dist_attr.dims_mapping().size(); + if (!IsEmpty(attn_mask_shape)) { + PADDLE_ENFORCE_EQ( + mask_ndim, + mask_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor mask's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + mask_ndim, + mask_dims_mapping_size)); + } auto out_shape = phi::vectorize(out.dims()); - auto out_dist_attr = attn_mask.dist_attr(); + auto out_dist_attr = out.dist_attr(); + + auto softmax_lse_shape = phi::vectorize(softmax_lse.dims()); + auto softmax_lse_dist_attr = softmax_lse.dist_attr(); auto out_grad_shape = phi::vectorize(out_grad.dims()); auto out_grad_dist_attr = out_grad.dist_attr(); @@ -458,9 +466,8 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, // [batch_size, num_heads, seq_len_q, seq_len_kv] std::string softmax_axes = { batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; - // [batch_size, num_heads, seq_len_q, seq_len_kv] - std::string softmax_lse_axes = { - batch_axis, num_heads_axis, seq_len_q_axis, seq_len_kv_axis}; + // [batch_size, num_heads, seq_len_q] + std::string softmax_lse_axes = {batch_axis, num_heads_axis, seq_len_q_axis}; std::string q_axes_align = q_axes; q_axes_align[1] = alphabet[used_axes_index++]; @@ -483,104 +490,56 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, out_grad_axes_align[3] = alphabet[used_axes_index++]; std::vector>> axes_sharding_info; - axes_sharding_info.emplace_back(q_axes_align, q_dist_attr.dims_mapping()); axes_sharding_info.emplace_back(k_axes_align, k_dist_attr.dims_mapping()); axes_sharding_info.emplace_back(v_axes_align, k_dist_attr.dims_mapping()); axes_sharding_info.emplace_back(out_axes_align, out_dist_attr.dims_mapping()); axes_sharding_info.emplace_back(out_grad_axes_align, 
out_grad_dist_attr.dims_mapping()); - auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info); - auto q_dist_attr_dst = CopyTensorDistAttrForOutput(q_dist_attr); - auto q_dims_mapping = GetDimsMappingForAxes(q_axes, axis_to_dim_map, true); - q_dist_attr_dst.set_dims_mapping(q_dims_mapping); - auto k_dist_attr_dst = CopyTensorDistAttrForOutput(k_dist_attr); - auto k_dims_mapping = GetDimsMappingForAxes(k_axes, axis_to_dim_map, true); - k_dist_attr_dst.set_dims_mapping(k_dims_mapping); - auto v_dist_attr_dst = CopyTensorDistAttrForOutput(v_dist_attr); - auto v_dims_mapping = GetDimsMappingForAxes(v_axes, axis_to_dim_map, true); - v_dist_attr_dst.set_dims_mapping(v_dims_mapping); - auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr); - auto out_dims_mapping = - GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); - out_dist_attr_dst.set_dims_mapping(out_dims_mapping); + auto q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes); + auto k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes); + auto v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes); + auto out_dist_attr_dst = MapDims(out_dist_attr, axis_to_dim_map, out_axes); + auto softmax_lse_dist_attr_dst = + MapDims(softmax_lse_dist_attr, axis_to_dim_map, softmax_lse_axes); - // TODO(liuzhenhai): process fixed_seed_offset and attn_mask - auto fixed_seed_offset_dist_attr = seed_offset.dist_attr(); - auto attn_mask_dist_attr = attn_mask.dist_attr(); - - auto out_grad_dist_attr_dst = CopyTensorDistAttrForOutput(out_grad_dist_attr); - auto out_grad_dims_mapping = - GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); - v_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping); - - TensorDistAttr q_grad; - auto q_grad_dims_mapping = - GetDimsMappingForAxes(q_axes, axis_to_dim_map, true); - q_grad.set_dims_mapping(q_grad_dims_mapping); - - TensorDistAttr k_grad; - auto k_grad_dims_mapping = - GetDimsMappingForAxes(k_axes, axis_to_dim_map, true); - k_grad.set_dims_mapping(k_grad_dims_mapping); + // TODO(liuzhenhai): process seed and attn_mask + auto& seed_offset_dist_attr_dst = seed_offset_dist_attr; + auto& attn_mask_dist_attr_dst = attn_mask_dist_attr; - TensorDistAttr v_grad; - auto v_grad_dims_mapping = - GetDimsMappingForAxes(v_axes, axis_to_dim_map, true); - v_grad.set_dims_mapping(v_grad_dims_mapping); + auto out_grad_dist_attr_dst = + MapDims(out_grad_dist_attr, axis_to_dim_map, out_axes); + auto q_grad = MapDims(q_dist_attr, axis_to_dim_map, q_axes); + auto k_grad = MapDims(k_dist_attr, axis_to_dim_map, k_axes); + auto v_grad = MapDims(v_dist_attr, axis_to_dim_map, v_axes); VLOG(4) << "FlashAttInferSpmd:"; VLOG(4) << "Einsum Notation: " << q_axes << "," << k_axes << "," << v_axes - << "-->" << out_axes << "," << softmax_axes << "," - << softmax_lse_axes; - - VLOG(4) << "q"; - VLOG(4) << "Input shape: [" << str_join(q_shape) << "] " - << "src_dims_mapping: [" << str_join(q_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(q_dims_mapping) << "]"; - - VLOG(4) << "k"; - VLOG(4) << "Input shape: [" << str_join(k_shape) << "] " - << "src_dims_mapping: [" << str_join(k_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; - - VLOG(4) << "v"; - VLOG(4) << "Input shape: [" << str_join(v_shape) << "] " - << "src_dims_mapping: [" << str_join(v_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(v_dims_mapping) << "]"; - - VLOG(4) << "out"; - VLOG(4) << "Input shape: [" << str_join(out_shape) << 
"] " - << "src_dims_mapping: [" << str_join(out_dist_attr.dims_mapping()) - << "] " - << "dst_dims_mapping: [" << str_join(out_dims_mapping) << "]"; - - VLOG(4) << "out_grad"; - VLOG(4) << "Input shape: [" << str_join(out_grad_shape) << "] " - << "src_dims_mapping: [" - << str_join(out_grad_dist_attr.dims_mapping()) << "] " - << "dst_dims_mapping: [" << str_join(out_grad_dims_mapping) << "]"; - - VLOG(4) << "q_grad" - << " dims_mapping: [" << str_join(q_grad_dims_mapping) << "]"; - VLOG(4) << "k_grad" - << " dims_mapping: [" << str_join(k_grad_dims_mapping) << "]"; - VLOG(4) << "v_grad" - << " dims_mapping: [" << str_join(v_grad_dims_mapping) << "]"; - - VLOG(4) << std::endl; + << "-->" << out_axes << "," << softmax_axes << "," << softmax_lse_axes + << std::endl; + VLOG(4) << "Inputs:" << std::endl; + LOG_SPMD_INPUT(q); + LOG_SPMD_INPUT(k); + LOG_SPMD_INPUT(v); + LOG_SPMD_INPUT(out); + LOG_SPMD_INPUT(softmax_lse); + LOG_SPMD_INPUT(seed_offset); + LOG_SPMD_INPUT(attn_mask); + LOG_SPMD_INPUT(out_grad); + VLOG(4) << "Outputs:" << std::endl; + LOG_SPMD_OUTPUT(q_grad); + LOG_SPMD_OUTPUT(k_grad); + LOG_SPMD_OUTPUT(v_grad); return {{q_dist_attr_dst, k_dist_attr_dst, v_dist_attr_dst, out_dist_attr_dst, - fixed_seed_offset_dist_attr, - attn_mask_dist_attr, + softmax_lse_dist_attr_dst, + seed_offset_dist_attr_dst, + attn_mask_dist_attr_dst, out_grad_dist_attr_dst}, {q_grad, k_grad, v_grad}}; } diff --git a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py index 9975b976248a4..e73c124bb0d1d 100644 --- a/test/auto_parallel/semi_auto_parallel_for_flash_attention.py +++ b/test/auto_parallel/semi_auto_parallel_for_flash_attention.py @@ -39,7 +39,7 @@ def test_flash_att_forward(self): inputs_shape=shapes, inputs_specs=specs, op_func=flash_attention, - with_backward=False, + with_backward=True, causal=True, ) self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) @@ -55,7 +55,7 @@ def test_flash_att_forward_reshard(self): inputs_shape=shapes, inputs_specs=specs, op_func=flash_attention, - with_backward=False, + with_backward=True, causal=True, ) self.check_dim_mapping(outputs[0], [0, -1, -1, -1]) diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py index 84567d19b30ac..3fc00db886ce7 100644 --- a/test/auto_parallel/semi_auto_parallel_util.py +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -98,12 +98,12 @@ def terminal_cond(x): input_np = np.random.random(size=shape).astype("float32") input = paddle.to_tensor(input_np) input = paddle.cast(input, self._dtype).detach() - input.stop_gradient = False + input.stop_gradient = not with_backward input_dist_attr = dist.DistAttr( mesh=self._mesh, sharding_specs=spec ) dist_input = dist.shard_tensor(input, dist_attr=input_dist_attr) - dist_input.stop_gradient = False + dist_input.stop_gradient = not with_backward flat_inputs.append(input) flat_dist_inputs.append(dist_input) inputs, _ = self.unflatten(flat_inputs, inputs_structure) @@ -127,8 +127,8 @@ def terminal_cond2(x): flat_dist_out, _ = self.flatten(dist_out, terminal_cond2) assert len(flat_out) == len(flat_dist_out) for output, dist_output in zip(flat_out, flat_dist_out): - self.check_tensor_eq(out, dist_out) - if out is not None: + self.check_tensor_eq(output, dist_output) + if output is not None: output.backward() dist_output.backward() From bbf18f67294ca7b84c17f896a4022171504d8537 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 22 Nov 2023 19:47:17 +0800 Subject: [PATCH 
From bbf18f67294ca7b84c17f896a4022171504d8537 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Wed, 22 Nov 2023 19:47:17 +0800
Subject: [PATCH 7/9] polish

---
 paddle/phi/api/yaml/generator/dist_api_gen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py
index 758dfdd6251ad..c63d91a09b883 100644
--- a/paddle/phi/api/yaml/generator/dist_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -917,7 +917,7 @@ def generate_output_creation_code(self) -> str:
                 and self.generate_general_infer_spmd
             ):
                 output_creation_code += SINGLE_INPLACE_OUT_DIST_ATTR
-            elif self.infer_meta['spmd_rule'] is not None:
+            if self.infer_meta['spmd_rule'] is not None:
                 output_creation_code += SINGLE_OUT_CREATION_TEMPLATE
             else:
                 output_creation_code += SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD

From fb1610043bb45e6246427e9eac14ac73fa9c0365 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Wed, 22 Nov 2023 23:19:12 +0800
Subject: [PATCH 8/9] polish

---
 .../infermeta/spmd_rules/flash_attention.cc   | 79 +++++++------------
 test/auto_parallel/semi_auto_parallel_util.py |  3 +-
 test/cpp/auto_parallel/spmd_rule_test.cc      | 63 +++++++++++++++
 3 files changed, 94 insertions(+), 51 deletions(-)

diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc
index 02ef6e49df1a9..669c29a91f3bb 100644
--- a/paddle/phi/infermeta/spmd_rules/flash_attention.cc
+++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc
@@ -223,29 +223,21 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
   // [batch_size, num_heads, seq_len_q, seq_len_kv]
   std::string softmax_lse_axes = {batch_axis, num_heads_axis, seq_len_q_axis};

-  std::string q_axes_align = q_axes;
-  q_axes_align[1] = alphabet[used_axes_index++];
-  q_axes_align[3] = alphabet[used_axes_index++];
-
-  std::string k_axes_align = k_axes;
-  k_axes_align[1] = alphabet[used_axes_index++];
-  k_axes_align[3] = alphabet[used_axes_index++];
-
-  std::string v_axes_align = v_axes;
-  v_axes_align[1] = alphabet[used_axes_index++];
-  v_axes_align[3] = alphabet[used_axes_index++];
+  auto q_dist_attr_dst = UnShardTensorDims(q_dist_attr, {1, 3});
+  auto k_dist_attr_dst = UnShardTensorDims(k_dist_attr, {1, 3});
+  auto v_dist_attr_dst = UnShardTensorDims(v_dist_attr, {1, 3});

   std::vector<std::pair<std::string, std::vector<int64_t>>>
       axes_sharding_info;
-  axes_sharding_info.emplace_back(q_axes_align, q_dist_attr.dims_mapping());
-  axes_sharding_info.emplace_back(k_axes_align, k_dist_attr.dims_mapping());
-  axes_sharding_info.emplace_back(v_axes_align, k_dist_attr.dims_mapping());
+  axes_sharding_info.emplace_back(q_axes, q_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(k_axes, k_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(v_axes, v_dist_attr_dst.dims_mapping());
   auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info);
-  auto q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes);
-  auto k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes);
-  auto v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes);
+  q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes);
+  k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes);
+  v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes);

   // TODO(liuzhenhai): process fixed_seed and attn_mask
   auto fixed_seed_offset_dist_attr_dst = fixed_seed_offset_dist_attr;
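Note on the two helpers this refactor leans on: MapDims and UnShardTensorDims are used in the hunk above (and in the grad hunk below) but are not defined in this patch. A minimal sketch of the behavior assumed here, distilled from the copy-and-set pattern the earlier hunks removed; the real helpers live in the shared spmd-rule utilities, so names and exact signatures are assumptions:

// Sketch only. MapDims: copy the source TensorDistAttr for output, then
// install the dims mapping obtained by projecting the merged
// axis -> mesh-dim map onto this tensor's einsum axes.
TensorDistAttr MapDims(
    const TensorDistAttr& src,
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
    const std::string& axes) {
  auto dst = CopyTensorDistAttrForOutput(src);
  dst.set_dims_mapping(GetDimsMappingForAxes(axes, axis_to_dim_map, true));
  return dst;
}

// Sketch only. UnShardTensorDims: force the listed tensor dims to be
// replicated so that, e.g., the sequence and head_dim axes never enter the
// sharding merge as sharded.
TensorDistAttr UnShardTensorDims(const TensorDistAttr& dist_attr,
                                 const std::vector<int64_t>& dims) {
  auto dst = dist_attr;
  auto dims_mapping = dist_attr.dims_mapping();
  for (auto dim : dims) {
    dims_mapping[dim] = -1;  // -1 means not sharded on any mesh dim
  }
  dst.set_dims_mapping(dims_mapping);
  return dst;
}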
@@ -469,48 +461,37 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q,
   // [batch_size, num_heads, seq_len_q]
   std::string softmax_lse_axes = {batch_axis, num_heads_axis, seq_len_q_axis};

-  std::string q_axes_align = q_axes;
-  q_axes_align[1] = alphabet[used_axes_index++];
-  q_axes_align[3] = alphabet[used_axes_index++];
-
-  std::string k_axes_align = k_axes;
-  k_axes_align[1] = alphabet[used_axes_index++];
-  k_axes_align[3] = alphabet[used_axes_index++];
-
-  std::string v_axes_align = v_axes;
-  v_axes_align[1] = alphabet[used_axes_index++];
-  v_axes_align[3] = alphabet[used_axes_index++];
-
-  std::string out_axes_align = out_axes;
-  out_axes_align[1] = alphabet[used_axes_index++];
-  out_axes_align[3] = alphabet[used_axes_index++];
-
-  std::string out_grad_axes_align = out_axes;
-  out_grad_axes_align[1] = alphabet[used_axes_index++];
-  out_grad_axes_align[3] = alphabet[used_axes_index++];
+  auto q_dist_attr_dst = UnShardTensorDims(q_dist_attr, {1, 3});
+  auto k_dist_attr_dst = UnShardTensorDims(k_dist_attr, {1, 3});
+  auto v_dist_attr_dst = UnShardTensorDims(v_dist_attr, {1, 3});
+  auto out_dist_attr_dst = UnShardTensorDims(out_dist_attr, {1, 3});
+  auto out_grad_dist_attr_dst = UnShardTensorDims(out_grad_dist_attr, {1, 3});
+  auto softmax_lse_dist_attr_dst =
+      UnShardTensorDims(softmax_lse_dist_attr, {2});

   std::vector<std::pair<std::string, std::vector<int64_t>>>
       axes_sharding_info;
-  axes_sharding_info.emplace_back(q_axes_align, q_dist_attr.dims_mapping());
-  axes_sharding_info.emplace_back(k_axes_align, k_dist_attr.dims_mapping());
-  axes_sharding_info.emplace_back(v_axes_align, k_dist_attr.dims_mapping());
-  axes_sharding_info.emplace_back(out_axes_align, out_dist_attr.dims_mapping());
-  axes_sharding_info.emplace_back(out_grad_axes_align,
-                                  out_grad_dist_attr.dims_mapping());
+  axes_sharding_info.emplace_back(q_axes, q_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(k_axes, k_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(v_axes, v_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(out_axes, out_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(out_axes,
+                                  out_grad_dist_attr_dst.dims_mapping());
+  axes_sharding_info.emplace_back(softmax_lse_axes,
+                                  softmax_lse_dist_attr_dst.dims_mapping());
   auto axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info);
-  auto q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes);
-  auto k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes);
-  auto v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes);
-  auto out_dist_attr_dst = MapDims(out_dist_attr, axis_to_dim_map, out_axes);
-  auto softmax_lse_dist_attr_dst =
+  q_dist_attr_dst = MapDims(q_dist_attr, axis_to_dim_map, q_axes);
+  k_dist_attr_dst = MapDims(k_dist_attr, axis_to_dim_map, k_axes);
+  v_dist_attr_dst = MapDims(v_dist_attr, axis_to_dim_map, v_axes);
+  out_dist_attr_dst = MapDims(out_dist_attr, axis_to_dim_map, out_axes);
+  softmax_lse_dist_attr_dst =
       MapDims(softmax_lse_dist_attr, axis_to_dim_map, softmax_lse_axes);

   // TODO(liuzhenhai): process seed and attn_mask
   auto& seed_offset_dist_attr_dst = seed_offset_dist_attr;
   auto& attn_mask_dist_attr_dst = attn_mask_dist_attr;
+  out_grad_dist_attr_dst =
+      MapDims(out_grad_dist_attr, axis_to_dim_map, out_axes);

-  auto out_grad_dist_attr_dst =
-      MapDims(out_grad_dist_attr, axis_to_dim_map, out_axes);
   auto q_grad = MapDims(q_dist_attr, axis_to_dim_map, q_axes);
   auto k_grad = MapDims(k_dist_attr, axis_to_dim_map, k_axes);
   auto v_grad = MapDims(v_dist_attr, axis_to_dim_map, v_axes);
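To make the merge step above concrete, here is an illustrative fragment (made-up axes and mappings, not taken from the patch) of how ShardingMergeForTensors unions per-tensor shardings before MapDims projects them back onto each tensor:

// q uses axes "abcd" (batch, seq_len_q, num_heads, head_dim) and shards
// batch on mesh dim 0 and num_heads on mesh dim 1; k uses "aecd" and
// shards batch only.
std::vector<std::pair<std::string, std::vector<int64_t>>> demo_info;
demo_info.emplace_back("abcd", std::vector<int64_t>{0, -1, 1, -1});   // q
demo_info.emplace_back("aecd", std::vector<int64_t>{0, -1, -1, -1});  // k
auto demo_map = ShardingMergeForTensors(demo_info);
// demo_map is {a: 0, b: -1, c: 1, d: -1, e: -1}; projecting it back with
// MapDims also shards k's num_heads axis ('c') on mesh dim 1.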
diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py
index 3fc00db886ce7..8f92f72eea20a 100644
--- a/test/auto_parallel/semi_auto_parallel_util.py
+++ b/test/auto_parallel/semi_auto_parallel_util.py
@@ -95,9 +95,8 @@ def terminal_cond(x):
         assert len(flat_inputs_specs) == len(flat_inputs_shape)
         for shape, spec in zip(flat_inputs_shape, flat_inputs_specs):
-            input_np = np.random.random(size=shape).astype("float32")
+            input_np = np.random.random(size=shape).astype(self._dtype)
             input = paddle.to_tensor(input_np)
-            input = paddle.cast(input, self._dtype).detach()
             input.stop_gradient = not with_backward
             input_dist_attr = dist.DistAttr(
                 mesh=self._mesh, sharding_specs=spec
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index 77c0c555b4564..410b406a5784e 100644
--- a/test/cpp/auto_parallel/spmd_rule_test.cc
+++ b/test/cpp/auto_parallel/spmd_rule_test.cc
@@ -1086,6 +1086,69 @@ TEST(LayerNorm, Ctor) {
   check_partial_dims(spmd2.second[2], {0});
 }
 
+TEST(FlashAtt, Ctor) {
+  std::vector<int64_t> mesh_shape = {2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  auto build_input = [&](const std::vector<int64_t>& shape,
+                         const std::vector<int64_t>& dim_mapping) {
+    auto t_dist_attr = TensorDistAttr();
+    t_dist_attr.set_process_mesh(process_mesh);
+    t_dist_attr.set_dims_mapping(dim_mapping);
+    t_dist_attr.set_dynamic_dims(std::vector<bool>(shape.size(), false));
+    auto input =
+        phi::distributed::DistMetaTensor(phi::make_ddim(shape), t_dist_attr);
+    return input;
+  };
+
+  // b, s, m, h
+  std::vector<int64_t> qkv_shape = {2, 256, 2, 128};
+  std::vector<int64_t> dim_mapping = {0, 1, -1, -1};
+
+  auto qkv = build_input(qkv_shape, dim_mapping);
+  auto mask = build_input({}, {});
+  auto seed_offset = build_input({}, {});
+
+  auto spmd1 = FlashAttInferSpmd(
+      qkv, qkv, qkv, seed_offset, mask, 0.5, false, false, false, "");
+
+  EXPECT_EQ(spmd1.first.size(), static_cast<size_t>(5));
+  EXPECT_EQ(spmd1.second.size(), static_cast<size_t>(4));
+  check_dim_mapping(spmd1.first[0], {0, -1, -1, -1});
+  check_dim_mapping(spmd1.first[1], {0, -1, -1, -1});
+  check_dim_mapping(spmd1.first[2], {0, -1, -1, -1});
+  check_dim_mapping(spmd1.first[3], {});
+  check_dim_mapping(spmd1.first[4], {});
+  check_dim_mapping(spmd1.second[0], {0, -1, -1, -1});
+  check_dim_mapping(spmd1.second[1], {0, -1, -1, -1});
+  check_dim_mapping(spmd1.second[2], {0, -1, -1});
+  check_dim_mapping(spmd1.second[3], {});
+
+  auto out = build_input(qkv_shape, {0, -1, 1, -1});
+  auto softmax_lse = build_input({2, 2, 256}, {0, 1, -1});
+  auto out_grad = build_input(qkv_shape, {-1, -1, -1, -1});
+
+  auto spmd2 = FlashAttGradInferSpmd(
+      qkv, qkv, qkv, out, softmax_lse, seed_offset, mask, out_grad, 0.5, false);
+
+  EXPECT_EQ(spmd2.first.size(), static_cast<size_t>(8));
+  EXPECT_EQ(spmd2.second.size(), static_cast<size_t>(3));
+
+  check_dim_mapping(spmd2.first[0], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.first[1], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.first[2], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.first[3], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.first[4], {0, 1, -1});
+  check_dim_mapping(spmd2.first[5], {});
+  check_dim_mapping(spmd2.first[6], {});
+  check_dim_mapping(spmd2.first[7], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.second[0], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.second[1], {0, -1, 1, -1});
+  check_dim_mapping(spmd2.second[2], {0, -1, 1, -1});
+}
+
 TEST(Util, Ctor) {
   // test equal test not equal
   using phi::distributed::PartialStatus;
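A quick walk-through of the grad-rule expectations in the new test, assuming the axis letters sketched earlier (q/out on "abcd", k/v on "aecd", softmax_lse on "acb"); this is reasoning about the asserts above, not code from the patch:

// out contributes a->0 and c->1; softmax_lse contributes a->0 and c->1;
// out_grad contributes nothing (all -1). The merged map is therefore
// {a: 0, b: -1, c: 1, d: -1, e: -1}, which projects back to
// {0, -1, 1, -1} for every [b, s, m, h] tensor and to {0, 1, -1} for
// softmax_lse -- exactly the dim mappings the checks assert.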
From 352333d852ef6e90da5e61ed544e55d15bdab401 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 23 Nov 2023 11:22:50 +0800
Subject: [PATCH 9/9] polish

---
 paddle/phi/infermeta/spmd_rules/flash_attention.cc | 5 ++++-
 paddle/phi/infermeta/spmd_rules/flash_attention.h  | 7 +++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc
index 669c29a91f3bb..c12f666523772 100644
--- a/paddle/phi/infermeta/spmd_rules/flash_attention.cc
+++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc
@@ -1,8 +1,11 @@
 /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
+
+http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.h b/paddle/phi/infermeta/spmd_rules/flash_attention.h
index 5cb881dd9ad50..c2c0add58f9b4 100644
--- a/paddle/phi/infermeta/spmd_rules/flash_attention.h
+++ b/paddle/phi/infermeta/spmd_rules/flash_attention.h
@@ -1,8 +1,11 @@
 /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
 http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -11,12 +14,8 @@ limitations under the License. */
 
 #pragma once
 
-#include <memory>
-#include <string>
 #include <vector>
-#include <tuple>
-#include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
 #include "paddle/phi/core/distributed/type_defs.h"
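For orientation, the header trimmed above presumably closes with the rule declarations exercised by the new test. The backward declaration, reconstructed from its call site in spmd_rule_test.cc and the hunk header in patch 8/9 (parameter names inferred, so treat this as a sketch rather than the authoritative header text):

SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q,
                               const DistMetaTensor& k,
                               const DistMetaTensor& v,
                               const DistMetaTensor& out,
                               const DistMetaTensor& softmax_lse,
                               const DistMetaTensor& seed_offset,
                               const DistMetaTensor& attn_mask,
                               const DistMetaTensor& out_grad,
                               float dropout,
                               bool causal);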