[Semi-Auto] Add cross_entropy_with_softmax infer_backward rule #56507

Merged 1 commit on Sep 8, 2023
@@ -34,7 +34,7 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward(
input_specs_size));

auto x_shape = input_specs[0].shape();
int x_ndim = static_cast<int>(x_shape.size());
int x_ndim = x_shape.size();
auto x_dist_attr_src = input_specs[0].dist_attr();
std::vector<int64_t> x_dims_mapping_src = x_dist_attr_src.dims_mapping();

@@ -173,10 +173,116 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward(
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
CrossEntropyWithSoftmaxSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of CrossEntropyWithSoftmaxSPMDRule is NOT implemented "
"yet."));
// step0: verify input args based on cross_entropy_with_softmax logic
int64_t ninputs = input_specs.size();
int64_t noutputs = output_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
2,
phi::errors::InvalidArgument("The size of InputSpec of cross entropy "
"with softmax should be 2, but got [%d].",
ninputs));
PADDLE_ENFORCE_EQ(
noutputs,
2,
phi::errors::InvalidArgument("The size of OutputSpec of cross entropy "
"with softmax should be 2, but got [%d].",
noutputs));
VerifySpecs(output_specs, "cross_entropy_with_softmax_backward");

// step1: build Einsum Notation
std::vector<int64_t> x_shape = input_specs[0].shape();
int64_t x_ndim = x_shape.size();
std::vector<int64_t> label_shape = input_specs[1].shape();

int axis = ExtractAttr<int>("axis", attrs);
int ignore_index = ExtractAttr<int>("ignore_index", attrs);
bool numeric_stable_mode = ExtractAttr<bool>("numeric_stable_mode", attrs);
bool use_softmax = ExtractAttr<bool>("use_softmax", attrs);
bool soft_label = ExtractAttr<bool>("soft_label", attrs);

// normalize axis
if (axis < 0) {
axis = x_ndim + axis;
}

std::string alphabet =
"abcdefghijlmnopqrstuvwxyz"; // k for softmax_normalize axis
std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
x_axes[axis] = 'k';
Contributor review comment:

Should check whether softmax_normalize is the last axis; if it is not, it cannot be sharded. And even if softmax_normalize is the last axis, it cannot be sharded when soft_label == true.
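A minimal standalone sketch of the check suggested above (Python, hypothetical helper name, not part of this PR's C++ changes): under the reviewer's reading, the softmax-normalized axis is shardable only when it is the last axis and hard labels are used.

```python
def can_shard_softmax_axis(axis: int, x_ndim: int, soft_label: bool) -> bool:
    """Reviewer-suggested condition: the softmax-normalized axis may be
    sharded only when it is the last axis and soft_label is False."""
    if axis < 0:  # normalize a negative axis, mirroring step 1 of the rule
        axis = x_ndim + axis
    return axis == x_ndim - 1 and not soft_label


# Rank-3 logits, softmax over the last axis, hard labels -> shardable.
assert can_shard_softmax_axis(-1, 3, soft_label=False)
# Soft labels, or a non-last normalized axis -> must stay replicated.
assert not can_shard_softmax_axis(-1, 3, soft_label=True)
assert not can_shard_softmax_axis(1, 3, soft_label=False)
```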

std::string label_axes = x_axes;
if (!soft_label) {
label_axes[axis] = '1';
}
std::string loss_axes = x_axes;
loss_axes[axis] = '1';
// optional output
std::string softmax_out_axes;
if (use_softmax) {
softmax_out_axes = x_axes;
} else {
softmax_out_axes = "";
}

// step2: Sharding Propagation
// step2.1 merge output dims mappings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info =
GetAxesDimsMappingPair({softmax_out_axes, loss_axes}, output_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// step2.2 infer inputs' dims mappings from merged dims mapping
std::vector<TensorDistAttr> input_dist_attrs;
input_dist_attrs.emplace_back(input_specs[0].dist_attr());
input_dist_attrs.emplace_back(input_specs[1].dist_attr());
// infer and set input X's dims mapping
input_dist_attrs[0].set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
// infer and set input label's dims mapping
input_dist_attrs[1].set_dims_mapping(
GetDimsMappingForAxes(label_axes, axis_to_dim_map));

// step2.3 update outputs' dims mappings with merged dims mapping
std::vector<TensorDistAttr> output_dist_attrs;
output_dist_attrs.emplace_back(output_specs[0].dist_attr()); // softmax_out
output_dist_attrs.emplace_back(output_specs[1].dist_attr()); // loss
output_dist_attrs[0].set_dims_mapping(
GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map));
output_dist_attrs[1].set_dims_mapping(
GetDimsMappingForAxes(loss_axes, axis_to_dim_map));

// step3: Handle partial state (TODO)

VLOG(4) << "CrossEntropyWithSoftmaxSPMDRule InferBackward: "
<< "axis: " << axis << ", ignore_index: " << ignore_index
<< ", numeric_stable_mode: "
<< (numeric_stable_mode ? "true" : "false")
<< ", use_softmax: " << use_softmax
<< ", soft_label: " << (soft_label ? "true" : "false");
VLOG(4) << "Einsum notation: [" << x_axes << "," << label_axes << " --> "
<< softmax_out_axes << "," << loss_axes << "]. (inputs --> outputs)";
for (int64_t i = 0; i < noutputs; i++) {
VLOG(4) << "Output" << std::to_string(i) << ": "
<< "shape: [" << str_join(output_specs[i].shape())
<< "], src_dims_mapping: ["
<< str_join(output_specs[i].dims_mapping())
<< "], dst_dims_mapping: ["
<< str_join(output_dist_attrs[i].dims_mapping()) << "]";
}
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << ": "
<< "shape: [" << str_join(input_specs[i].shape())
<< "], infered_dims_mapping: ["
<< str_join(input_dist_attrs[i].dims_mapping()) << "]";
}
VLOG(4) << std::endl;

// according to the phi api implementation, the softmax_out tensor will always
// be generated no matter the value of use_softmax.
return {input_dist_attrs, output_dist_attrs};
}

} // namespace auto_parallel
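For readers unfamiliar with the einsum-style notation built in step 1 above, a small illustrative sketch (standalone Python, example values only; the rule itself is the C++ code in this diff) for a rank-3 input with axis = -1 and hard labels:

```python
# Reconstruct the axes strings of step 1 for x_ndim = 3, axis = -1,
# soft_label = False, use_softmax = True (illustrative values only).
x_ndim, axis, soft_label, use_softmax = 3, -1, False, True
alphabet = "abcdefghijlmnopqrstuvwxyz"  # 'k' is reserved for the softmax axis
if axis < 0:
    axis = x_ndim + axis  # -1 -> 2

x_axes = list(alphabet[:x_ndim])
x_axes[axis] = "k"
x_axes = "".join(x_axes)  # "abk"

label_axes = x_axes if soft_label else x_axes[:axis] + "1" + x_axes[axis + 1:]  # "ab1"
loss_axes = x_axes[:axis] + "1" + x_axes[axis + 1:]  # "ab1"
softmax_out_axes = x_axes if use_softmax else ""  # "abk"

# Matches the VLOG format: [inputs --> outputs]
print(f"[{x_axes},{label_axes} --> {softmax_out_axes},{loss_axes}]")  # [abk,ab1 --> abk,ab1]
```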
@@ -27,7 +27,8 @@ class CrossEntropyWithSoftmaxSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py (119 additions, 0 deletions)

@@ -39,6 +39,9 @@ def setUp(self):
label_shape, label_tensor_dist_attr
)

self.loss_spec = DistTensorSpec(self.lable_dist_tensor_spec)
self.softmax_out_spec = DistTensorSpec(self.x_dist_tensor_spec)

self.attrs = {
'ignore_index': -1,
'axis': -1,
@@ -147,6 +150,122 @@ def test_cross_entropy_with_softmax_infer_forward(self):
)
self.attrs['axis'] = -1

def test_cross_entropy_with_softmax_infer_backward(self):
# GPT DP case
# [1, 0, -1], [1, 0, -1] (outputs) -->
# [1, 0, -1], [1, 0, -1], (inputs)
# [1, 0, -1], [1, 0, -1] (outputs)
self.attrs['axis'] = -1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = False
self.softmax_out_spec.set_dims_mapping([1, 0, -1])
self.loss_spec.set_dims_mapping([1, 0, -1])

result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
self.assertEqual(len(result_dist_attrs), 2)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 2)

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0, -1])

self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [1, 0, -1]
) # loss

# GPT MP case, shard normalized axis
# [-1, -1, 0], [-1, -1, -1] (outputs) -->
Contributor review comment:

Distinguish between the two outputs, loss (last axis is 1) and softmax_out, and between the two inputs, label (last axis may be 1) and logits.

# [-1, -1, 0], [-1, -1, -1], (inputs)
# [-1, -1, 0], [-1, -1, -1] (outputs)
self.attrs['axis'] = -1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = False
self.softmax_out_spec.set_dims_mapping([-1, -1, 0])
self.loss_spec.set_dims_mapping([-1, -1, -1])

result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1])

self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1]
) # loss
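Relating to the reviewer note on this case: a short illustration (hypothetical shapes, not taken from the PR or its tests) of why loss and a hard label carry a size-1 class axis while logits and softmax_out keep the full class dimension, so only the latter pair can be sharded on that axis:

```python
# Hypothetical GPT-like shapes for the rank-3, axis = -1, hard-label case.
batch, seq_len, num_classes = 8, 1024, 50304
x_shape = [batch, seq_len, num_classes]            # logits
softmax_out_shape = [batch, seq_len, num_classes]  # same layout as logits
label_shape = [batch, seq_len, 1]                  # hard label: class axis is size 1
loss_shape = [batch, seq_len, 1]                   # loss: class axis is size 1

# A size-1 axis cannot be sharded, so its dims mapping stays -1, which is why
# softmax_out can be [-1, -1, 0] while loss is [-1, -1, -1] in the case above.
```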

# GPT MP-DP case
# [-1, -1, 0], [1, -1, -1] (outputs) -->
# [1, -1, 0], [1, -1, -1], (inputs)
# [1, -1, 0], [1, -1, -1] (outputs)
self.attrs['axis'] = -1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = False
self.softmax_out_spec.set_dims_mapping([-1, -1, 0])
self.loss_spec.set_dims_mapping([1, -1, -1])

result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, -1])

self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [1, -1, -1]
) # loss

# Soft Label, normalized axis = 1
# [1, -1, 0], [1, -1, -1] (outputs) -->
# [1, -1, 0], [1, -1, 0], (inputs)
# [1, -1, 0], [1, -1, 0] (outputs)
self.attrs['axis'] = 1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = True
self.softmax_out_spec.set_dims_mapping([1, -1, 0])
self.loss_spec.set_dims_mapping([1, -1, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, 0])

self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [1, -1, 0]
) # loss


if __name__ == "__main__":
unittest.main()