Flash att infer #59083

Merged
merged 13 commits into from
Nov 23, 2023
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -830,6 +830,7 @@
   infer_meta :
     func : FlashAttnGradInferMeta
     param : [q, k, v]
+    spmd_rule : FlashAttGradInferSpmd
   kernel :
     func : flash_attn_grad
     data_type: q
13 changes: 12 additions & 1 deletion paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -768,6 +768,14 @@ def generate_specialized_infer_spmd_code(self) -> str:
         name=param
     )
     input_args_code += "meta_dist_input_" + param + ", "
+elif (
+    self.inputs['input_info'][param]
+    == "const paddle::optional<Tensor>&"
+):
+    input_decl_code += (
+        OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE.format(name=param)
+    )
+    input_args_code += "meta_dist_input_" + param + ", "
 elif (
     self.inputs['input_info'][param]
     == "const std::vector<Tensor>&"
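The hunk above adds a branch that picks a declaration template based on the C++ type string of each input. A minimal standalone sketch of that dispatch pattern follows. It is not the generator's actual code: the template bodies are placeholders (the real C++ snippets are not shown in this diff), the `"const Tensor&"` key is an assumed spelling for the plain-tensor case, the parameter name `attn_mask` is hypothetical, and the helper `build_infer_spmd_inputs` is invented for illustration.

```python
# Placeholder template bodies keyed by input type string. Only the optional
# and vector type strings appear in the diff; "const Tensor&" is an assumption.
TEMPLATES = {
    "const Tensor&": "/* single dist meta for {name} */\n",
    "const paddle::optional<Tensor>&": "/* optional dist meta for {name} */\n",
    "const std::vector<Tensor>&": "/* vector dist meta for {name} */\n",
}


def build_infer_spmd_inputs(input_info, params):
    """Return (declaration code, argument list code) for the given params.

    Hypothetical helper: mirrors the if/elif chain in the diff with a
    dictionary lookup instead of one branch per type string.
    """
    decl_code = ""
    args_code = ""
    for param in params:
        template = TEMPLATES.get(input_info.get(param))
        if template is None:
            # Input types this sketch does not know about are skipped.
            continue
        decl_code += template.format(name=param)
        args_code += "meta_dist_input_" + param + ", "
    return decl_code, args_code


# "q" comes from the YAML in this PR; "attn_mask" is a made-up optional input.
decl, args = build_infer_spmd_inputs(
    {"q": "const Tensor&", "attn_mask": "const paddle::optional<Tensor>&"},
    ["q", "attn_mask"],
)
print(decl)
print(args)
```

A lookup table keeps each new input type to a one-line change, whereas the chain in the diff needs a new `elif` per type. Whether that trade-off suits the real generator depends on how much the branches differ beyond the template name.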
@@ -900,7 +908,10 @@ def generate_output_creation_code(self) -> str:
 # kernel output generate
 self.dist_output_args.append('dist_out')
 self.dense_output_args.append('dense_out')
-if self.outputs['types'][0] == 'Tensor':
+if (
+    self.outputs['types'][0] == 'Tensor'
+    or self.outputs['types'][0] == 'const paddle::optional<Tensor>'
+):
     if (
         self.need_to_generate_code_for_inplace_impl(0)
         and self.generate_general_infer_spmd
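The widened condition in this hunk compares the first output type against two spellings. One possible way to express the same check is a membership test, sketched below. The constant `SINGLE_TENSOR_OUTPUT_TYPES` and the function `is_single_tensor_output` are hypothetical names introduced here, not part of the PR.

```python
# The two output type strings the diff treats as a single tensor output.
SINGLE_TENSOR_OUTPUT_TYPES = ("Tensor", "const paddle::optional<Tensor>")


def is_single_tensor_output(output_types):
    """True if the first output type is one of the single-tensor spellings."""
    return bool(output_types) and output_types[0] in SINGLE_TENSOR_OUTPUT_TYPES


print(is_single_tensor_output(["Tensor"]))
print(is_single_tensor_output(["const paddle::optional<Tensor>"]))
print(is_single_tensor_output(["std::vector<Tensor>"]))
```

Unlike the condition in the diff, this sketch also guards against an empty type list, which is an extra assumption about how the check might be called.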
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -929,6 +929,7 @@
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
+    spmd_rule : FlashAttInferSpmd
   kernel :
     func : flash_attn
     data_type : q