[Semi Auto] Entropy SPMD Rule #55394
Merged
zhiqiu merged 74 commits into PaddlePaddle:develop from JZ-LIANG:semi-auto/entropy-rule on Jul 20, 2023
Changes from all commits (74 commits)
a47bf99 base rule (JZ-LIANG)
21b5a75 add sharidng merge (JZ-LIANG)
f7e39d7 add sharidng axis merge (JZ-LIANG)
c92992d define unified data class for inferencing dist_attr (pkuzyc)
42a7b77 test wrap DistTensorSpec in dygraph mode (pkuzyc)
180edcc matmul main logic done (JZ-LIANG)
ecbb1ae Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (JZ-LIANG)
f314b56 Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG)
46153c7 shape int64 (JZ-LIANG)
1dcb80e common cc (JZ-LIANG)
198bc1f define unified data class for inferencing dist_attr (pkuzyc)
09d82a5 test wrap DistTensorSpec in dygraph mode (pkuzyc)
c3ea2a6 define python api and wrap function in static mode for DistTensorSpec (pkuzyc)
4cd1a2c revise syntax (JZ-LIANG)
3631f06 Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG)
ed0c31e map bugfix (JZ-LIANG)
701d3fa broadcast func (JZ-LIANG)
c1545a4 compile 1 (JZ-LIANG)
3ca2b73 add unitest (JZ-LIANG)
747de08 add registry (JZ-LIANG)
968ce61 Merge branch 'semi-auto/rule-base' of https://github.com/JZ-LIANG/Pad… (JZ-LIANG)
7be672d update unitest (JZ-LIANG)
3389b7e bugfix (JZ-LIANG)
3882a2c bugfix (JZ-LIANG)
3719a5a add pybind (JZ-LIANG)
73f49a8 bugfix (JZ-LIANG)
aced5ea bugfix macro gloabl name space (JZ-LIANG)
ef92dc4 bugfix macro gloabl name space (JZ-LIANG)
adcb470 segment fault (JZ-LIANG)
43df148 pybind (JZ-LIANG)
27803af pybind test (JZ-LIANG)
5612da9 pybind bugfixed1 (JZ-LIANG)
f9bd281 pybind bugfixed2 (JZ-LIANG)
18f8d29 pybind unitest (JZ-LIANG)
2628043 Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base (JZ-LIANG)
68a512a merge dev (JZ-LIANG)
f2b2edb merge dev (JZ-LIANG)
132558a merge dev (JZ-LIANG)
f3bc740 fixed cmake conflict (JZ-LIANG)
c11cdd2 fixed cmake conflict (JZ-LIANG)
491bf65 rename get method (JZ-LIANG)
041abd4 revise inferforward output type (JZ-LIANG)
60c90d3 revise comment (JZ-LIANG)
d5d7557 replicated rule (JZ-LIANG)
44e9404 replicated rule 2 (JZ-LIANG)
7657ee5 revert bug deps (JZ-LIANG)
223f960 Merge branch 'semi-auto/revert-phi-dep' into semi-auto/replicated-rule (JZ-LIANG)
5dc1be3 add rule (JZ-LIANG)
3ce0e74 add unitest (JZ-LIANG)
80f2a03 add rule (JZ-LIANG)
062970d add unitest (JZ-LIANG)
0cb4a9c move ut of auto_parallel (zhiqiu)
ab67ce1 fix ut (zhiqiu)
f9675bd Merge remote-tracking branch 'upstream/develop' into semi-auto/embedd… (JZ-LIANG)
7e31dea Merge branch 'dev/mv_ut' of https://github.com/zhiqiu/Paddle into sem… (JZ-LIANG)
694b310 Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule (JZ-LIANG)
2d4e938 bugfix (JZ-LIANG)
f45eca8 bugfix (JZ-LIANG)
e9b4ddc bugfix (JZ-LIANG)
43a4373 bugfix (JZ-LIANG)
ad31f1b bugfix (JZ-LIANG)
9ca9969 bugfix (JZ-LIANG)
dfad99d bugfix (JZ-LIANG)
fc3dfe6 Merge remote-tracking branch 'upstream/develop' into semi-auto/embedd… (JZ-LIANG)
def09f0 Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule (JZ-LIANG)
934cc61 resolute input sharding conflict maybe (JZ-LIANG)
daf098a Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule (JZ-LIANG)
99a10f4 fixed comment (JZ-LIANG)
49257b9 Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop… (JZ-LIANG)
fcf2ccb add rule (JZ-LIANG)
4d2a854 Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop… (JZ-LIANG)
6f7199a add unitest (JZ-LIANG)
7c69300 Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop… (JZ-LIANG)
013412f fixed typoes (JZ-LIANG)
paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc (184 additions, 0 deletions)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
CrossEntropyWithSoftmaxSPMDRule::InferForward(
    const std::vector<DistTensorSpec>& input_specs,
    const paddle::framework::AttributeMap& attrs) {
  // step0: verify input args based on cross_entropy_with_softmax logic
  auto input_specs_size = input_specs.size();
  PADDLE_ENFORCE_EQ(
      input_specs_size,
      2,
      phi::errors::InvalidArgument("The size of InputSpec of cross entropy "
                                   "with softmax should be 2, but got [%d].",
                                   input_specs_size));

  auto x_shape = input_specs[0].shape();
  int x_ndim = x_shape.size();
  auto x_dist_attr_src = input_specs[0].dist_attr();
  std::vector<int64_t> x_dims_mapping_src = x_dist_attr_src.dims_mapping();

  auto label_shape = input_specs[1].shape();
  auto label_dist_attr_src = input_specs[1].dist_attr();
  std::vector<int64_t> label_dims_mapping_src =
      label_dist_attr_src.dims_mapping();

  int axis = ExtractAttr<int>("axis", attrs);
  int ignore_index = ExtractAttr<int>("ignore_index", attrs);
  bool numeric_stable_mode = ExtractAttr<bool>("numeric_stable_mode", attrs);
  bool use_softmax = ExtractAttr<bool>("use_softmax", attrs);
  bool soft_label = ExtractAttr<bool>("soft_label", attrs);

  VLOG(6) << "CrossEntropyWithSoftmaxSPMDRule InferForward Inputs: "
          << "X shape: [" << str_join(x_shape) << "], x_dims_mapping_src: ["
          << str_join(x_dims_mapping_src) << "]; Label shape: ["
          << str_join(label_shape) << "], Label dims mapping: ["
          << str_join(label_dims_mapping_src) << "]; axis: "
          << "[" << axis << "], ignore_index: [" << ignore_index
          << "], numeric_stable_mode: [" << numeric_stable_mode
          << "], use_softmax: [" << use_softmax << "], soft_label: ["
          << soft_label << "].";

  // normalize axis
  if (axis < 0) {
    axis = x_ndim + axis;
  }

  // Trying to shard the softmax_normalize axis of the input, BUT the
  // c_softmax_with_entropy kernel does not support:
  // 1. soft label
  // 2. axis != -1
  // Support for these two cases will be added in the future.
  if (x_dims_mapping_src[axis] > -1) {
    PADDLE_ENFORCE_EQ(
        soft_label,
        false,
        phi::errors::InvalidArgument(
            "Trying to shard the softmax_normalize axis of the input tensor, "
            "but the soft_label is set as True, which is not supported yet!"));

    PADDLE_ENFORCE_EQ(
        axis,
        x_ndim - 1,
        phi::errors::InvalidArgument(
            "Trying to shard the softmax_normalize axis of the input tensor, "
            "but the softmax_normalize axis is not the last axis, which is "
            "not supported yet! The softmax_normalize axis is [%d].",
            axis));

    PADDLE_ENFORCE_EQ(use_softmax,
                      true,
                      phi::errors::InvalidArgument(
                          "Trying to shard the softmax_normalize axis of the "
                          "input tensor, use_softmax must be set to True!"));
  }

  // step1: build Einsum Notation
  std::string alphabet =
      "abcdefghijlmnopqrstuvwxyz";  // k for softmax_normalize axis
  std::string broadcast_axes =
      GetBroadcastAxes(x_ndim - 1, x_ndim - 1, alphabet);
  std::string x_axes = broadcast_axes;
  x_axes.insert(axis, "k");
  std::string label_axes;
  if (soft_label) {
    label_axes = x_axes;
  } else {
    label_axes = broadcast_axes;
    label_axes.insert(axis, "1");
  }
  std::string loss_axes = broadcast_axes;
  loss_axes.insert(axis, "1");
  // optional output
  std::string softmax_out_axes;
  if (use_softmax) {
    softmax_out_axes = x_axes;
  } else {
    softmax_out_axes = "";
  }

  // step2: Sharding Propagation
  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
  axes_sharding_info =
      GetAxesDimsMappingPair({x_axes, label_axes}, input_specs);
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors(axes_sharding_info);

  // step3: Infer dst Dims Mapping.
  TensorDistAttr loss_dist_attr_dst =
      CopyTensorDistAttrForOutput(label_dist_attr_src);
  loss_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(loss_axes, axis_to_dim_map));
  TensorDistAttr softmax_out_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  softmax_out_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map));

  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(x_axes, axis_to_dim_map));
  TensorDistAttr label_dist_attr_dst =
      CopyTensorDistAttrForOutput(label_dist_attr_src);
  label_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(label_axes, axis_to_dim_map));

  VLOG(4) << "CrossEntropyWithSoftmaxSPMDRule InferForward Inputs: "
          << "Einsum notation: [" << x_axes << "," << label_axes << " --> "
          << softmax_out_axes << "," << loss_axes << "]. " << std::endl
          << "X shape: [" << str_join(x_shape) << "], x_dims_mapping_src: ["
          << str_join(x_dims_mapping_src) << "], x_dims_mapping_dst: ["
          << str_join(x_dist_attr_dst.dims_mapping()) << "]; Label shape: ["
          << str_join(label_shape) << "], label_dims_mapping_src: ["
          << str_join(label_dims_mapping_src) << "], label_dims_mapping_dst: ["
          << str_join(label_dist_attr_dst.dims_mapping())
          << "]; loss_dims_mapping: ["
          << str_join(loss_dist_attr_dst.dims_mapping())
          << "], softmax_out_dims_mapping_src: ["
          << str_join(softmax_out_dist_attr_dst.dims_mapping()) << "]; axis: "
          << "[" << axis << "], ignore_index: [" << ignore_index
          << "], numeric_stable_mode: ["
          << (numeric_stable_mode ? "true" : "false") << "], use_softmax: ["
          << (use_softmax ? "true" : "false") << "], soft_label: ["
          << (soft_label ? "true" : "false") << "].";

  // TODO: if the softmax_normalize axis is sharded, notify the downstream phi
  // api to select c_softmax_with_entropy_kernel.

  // According to the phi api implementation, the softmax_out tensor will
  // always be generated no matter the value of use_softmax.
  return {{x_dist_attr_dst, label_dist_attr_dst},
          {softmax_out_dist_attr_dst, loss_dist_attr_dst}};
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
CrossEntropyWithSoftmaxSPMDRule::InferBackward(
    const std::vector<DistTensorSpec>& input_specs,
    const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "InferBackward of CrossEntropyWithSoftmaxSPMDRule is NOT implemented "
      "yet."));
}

}  // namespace auto_parallel
}  // namespace distributed
}  // namespace paddle
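To make steps 1-3 above concrete, here is a small standalone Python sketch of the notation-building and sharding-merge idea. It is only an illustration: the helpers below are simplified stand-ins for GetBroadcastAxes / ShardingMergeForTensors / GetDimsMappingForAxes (not the actual Paddle utilities), and conflict resolution between contradicting shardings is omitted.

# Sketch of the einsum-notation approach used by the rule (steps 1-3).
# Hypothetical helpers, not the Paddle API; sharding-conflict handling is omitted.
def build_axes(x_ndim, axis, soft_label):
    # "k" marks the softmax_normalize axis; the remaining axes are broadcast axes.
    alphabet = "abcdefghijlmnopqrstuvwxyz"  # no "k", it is reserved
    broadcast = alphabet[: x_ndim - 1]
    x_axes = broadcast[:axis] + "k" + broadcast[axis:]
    label_axes = x_axes if soft_label else broadcast[:axis] + "1" + broadcast[axis:]
    loss_axes = broadcast[:axis] + "1" + broadcast[axis:]
    return x_axes, label_axes, loss_axes


def merge_shardings(pairs):
    # Take the first non-replicated mesh dim seen per axis letter (-1 means replicated).
    axis_to_dim = {}
    for axes, dims_mapping in pairs:
        for letter, dim in zip(axes, dims_mapping):
            if letter != "1" and axis_to_dim.get(letter, -1) == -1:
                axis_to_dim[letter] = dim
    return axis_to_dim


def dims_for(axes, axis_to_dim):
    return [axis_to_dim.get(letter, -1) for letter in axes]


# GPT DP case from the unit test below: x is [1, -1, -1], label is [-1, 0, -1].
x_axes, label_axes, loss_axes = build_axes(x_ndim=3, axis=2, soft_label=False)
merged = merge_shardings([(x_axes, [1, -1, -1]), (label_axes, [-1, 0, -1])])
print(dims_for(x_axes, merged))      # [1, 0, -1]
print(dims_for(label_axes, merged))  # [1, 0, -1]
print(dims_for(loss_axes, merged))   # [1, 0, -1]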
paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h (35 additions, 0 deletions)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

class CrossEntropyWithSoftmaxSPMDRule : public SPMDRuleBase {
 public:
  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
  InferForward(const std::vector<DistTensorSpec>& input_specs,
               const paddle::framework::AttributeMap& attrs) override;

  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
  InferBackward(const std::vector<DistTensorSpec>& output_specs,
                const paddle::framework::AttributeMap& attrs) override;
};
}  // namespace auto_parallel
}  // namespace distributed
}  // namespace paddle
test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py (152 additions, 0 deletions)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto


class TestCrossEntropyWithSoftmaxSPMDRule(unittest.TestCase):
    def setUp(self):
        self.rule1 = get_spmd_rule("cross_entropy_with_softmax")

        x_shape = [8, 1024, 50304]  # [batch_size, max_seq_len, vocab_size]
        label_shape = [8, 1024, 1]

        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        label_tensor_dist_attr = TensorDistAttr()
        label_tensor_dist_attr.process_mesh = process_mesh
        self.lable_dist_tensor_spec = DistTensorSpec(
            label_shape, label_tensor_dist_attr
        )

        self.attrs = {
            'ignore_index': -1,
            'axis': -1,
            'numeric_stable_mode': True,
            'use_softmax': True,
            'soft_label': False,
        }

    def test_cross_entropy_with_softmax_infer_forward(self):
        # GPT DP case
        self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        self.lable_dist_tensor_spec.set_dims_mapping([-1, 0, -1])

        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec], self.attrs
        )
        self.assertEqual(len(result_dist_attrs), 2)
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(infered_input_dist_attrs), 2)
        self.assertEqual(len(infered_output_dist_attrs), 2)

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0, -1])
        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0, -1])

        self.assertEqual(
            infered_output_dist_attrs[1].dims_mapping, [1, 0, -1]
        )  # loss
        self.assertEqual(
            infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]
        )  # softmax output

        # GPT MP case, shard normalized axis
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
        self.lable_dist_tensor_spec.set_dims_mapping([-1, -1, -1])

        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1])

        self.assertEqual(
            infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1]
        )  # loss
        self.assertEqual(
            infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]
        )  # softmax output

        # GPT MP-DP case
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
        self.lable_dist_tensor_spec.set_dims_mapping([1, -1, -1])

        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, -1])

        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])

        # Soft Label Error
        self.attrs['soft_label'] = True
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
        self.lable_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        with self.assertRaises(ValueError):
            result_dist_attrs = self.rule1.infer_forward(
                [self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
                self.attrs,
            )
        self.attrs['soft_label'] = False

        # Normalized axis
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0])
        self.lable_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, 0])

        self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])
        self.attrs['axis'] = -1

        # Soft Normalized axis Error
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1])
        self.lable_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        with self.assertRaises(ValueError):
            result_dist_attrs = self.rule1.infer_forward(
                [self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
                self.attrs,
            )
        self.attrs['axis'] = -1


if __name__ == "__main__":
    unittest.main()
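For readers who want to poke at the rule outside the unittest scaffolding, a condensed usage sketch follows. It reuses exactly the Python API exercised by the test above (get_spmd_rule, DistTensorSpec, TensorDistAttr, infer_forward) and assumes a Paddle build that already contains this PR; the chosen shapes and mappings are illustrative only.

# Condensed usage sketch, mirroring the unit test above.
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto

rule = get_spmd_rule("cross_entropy_with_softmax")
mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

x_attr, label_attr = TensorDistAttr(), TensorDistAttr()
x_attr.process_mesh = mesh
label_attr.process_mesh = mesh

x_spec = DistTensorSpec([8, 1024, 50304], x_attr)  # [batch, seq, vocab]
label_spec = DistTensorSpec([8, 1024, 1], label_attr)
x_spec.set_dims_mapping([0, -1, -1])       # shard the batch axis on mesh dim 0
label_spec.set_dims_mapping([-1, -1, -1])  # label replicated

attrs = {
    'ignore_index': -1,
    'axis': -1,
    'numeric_stable_mode': True,
    'use_softmax': True,
    'soft_label': False,
}
input_attrs, output_attrs = rule.infer_forward([x_spec, label_spec], attrs)
# output_attrs[0] is softmax_out, output_attrs[1] is loss; here both end up
# with dims_mapping [0, -1, -1], i.e. sharded along the batch axis.
print(output_attrs[1].dims_mapping)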
Review comment:
The two ops have the same attrs?
Reply:
Yes, this is to accommodate the difference between Paddle's static mode and dygraph mode; we will unify them in the future.