[Semi Auto] Entropy SPMD Rule #55394

Merged: 74 commits, Jul 20, 2023
Commits (74)
a47bf99
base rule
JZ-LIANG May 16, 2023
21b5a75
add sharidng merge
JZ-LIANG May 18, 2023
f7e39d7
add sharidng axis merge
JZ-LIANG May 19, 2023
c92992d
define unified data class for inferencing dist_attr
pkuzyc May 18, 2023
42a7b77
test wrap DistTensorSpec in dygraph mode
pkuzyc May 19, 2023
180edcc
matmul main logic done
JZ-LIANG May 23, 2023
ecbb1ae
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JZ-LIANG May 23, 2023
f314b56
Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base
JZ-LIANG May 23, 2023
46153c7
shape int64
JZ-LIANG May 23, 2023
1dcb80e
common cc
JZ-LIANG May 23, 2023
198bc1f
define unified data class for inferencing dist_attr
pkuzyc May 18, 2023
09d82a5
test wrap DistTensorSpec in dygraph mode
pkuzyc May 19, 2023
c3ea2a6
define python api and wrap function in static mode for DistTensorSpec
pkuzyc May 23, 2023
4cd1a2c
revise syntax
JZ-LIANG May 24, 2023
3631f06
Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base
JZ-LIANG May 24, 2023
ed0c31e
map bugfix
JZ-LIANG May 29, 2023
701d3fa
broadcast func
JZ-LIANG May 29, 2023
c1545a4
compile 1
JZ-LIANG May 29, 2023
3ca2b73
add unitest
JZ-LIANG May 31, 2023
747de08
add registry
JZ-LIANG Jun 6, 2023
968ce61
Merge branch 'semi-auto/rule-base' of https://github.com/JZ-LIANG/Pad…
JZ-LIANG Jun 6, 2023
7be672d
update unitest
JZ-LIANG Jun 6, 2023
3389b7e
bugfix
JZ-LIANG Jun 6, 2023
3882a2c
bugfix
JZ-LIANG Jun 6, 2023
3719a5a
add pybind
JZ-LIANG Jun 6, 2023
73f49a8
bugfix
JZ-LIANG Jun 6, 2023
aced5ea
bugfix macro gloabl name space
JZ-LIANG Jun 6, 2023
ef92dc4
bugfix macro gloabl name space
JZ-LIANG Jun 6, 2023
adcb470
segment fault
JZ-LIANG Jun 8, 2023
43df148
pybind
JZ-LIANG Jun 8, 2023
27803af
pybind test
JZ-LIANG Jun 8, 2023
5612da9
pybind bugfixed1
JZ-LIANG Jun 14, 2023
f9bd281
pybind bugfixed2
JZ-LIANG Jun 14, 2023
18f8d29
pybind unitest
JZ-LIANG Jun 14, 2023
2628043
Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base
JZ-LIANG Jun 16, 2023
68a512a
merge dev
JZ-LIANG Jun 16, 2023
f2b2edb
merge dev
JZ-LIANG Jun 16, 2023
132558a
merge dev
JZ-LIANG Jun 16, 2023
f3bc740
fixed cmake conflict
JZ-LIANG Jun 16, 2023
c11cdd2
fixed cmake conflict
JZ-LIANG Jun 16, 2023
491bf65
rename get method
JZ-LIANG Jun 20, 2023
041abd4
revise inferforward output type
JZ-LIANG Jun 20, 2023
60c90d3
revise comment
JZ-LIANG Jun 20, 2023
d5d7557
replicated rule
JZ-LIANG Jun 21, 2023
44e9404
replicated rule 2
JZ-LIANG Jun 21, 2023
7657ee5
revert bug deps
JZ-LIANG Jun 27, 2023
223f960
Merge branch 'semi-auto/revert-phi-dep' into semi-auto/replicated-rule
JZ-LIANG Jun 27, 2023
5dc1be3
add rule
JZ-LIANG Jun 28, 2023
3ce0e74
add unitest
JZ-LIANG Jun 28, 2023
80f2a03
add rule
JZ-LIANG Jun 29, 2023
062970d
add unitest
JZ-LIANG Jul 6, 2023
0cb4a9c
move ut of auto_parallel
zhiqiu Jul 6, 2023
ab67ce1
fix ut
zhiqiu Jul 7, 2023
f9675bd
Merge remote-tracking branch 'upstream/develop' into semi-auto/embedd…
JZ-LIANG Jul 7, 2023
7e31dea
Merge branch 'dev/mv_ut' of https://github.com/zhiqiu/Paddle into sem…
JZ-LIANG Jul 7, 2023
694b310
Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule
JZ-LIANG Jul 7, 2023
2d4e938
bugfix
JZ-LIANG Jul 7, 2023
f45eca8
bugfix
JZ-LIANG Jul 7, 2023
e9b4ddc
bugfix
JZ-LIANG Jul 7, 2023
43a4373
bugfix
JZ-LIANG Jul 7, 2023
ad31f1b
bugfix
JZ-LIANG Jul 7, 2023
9ca9969
bugfix
JZ-LIANG Jul 7, 2023
dfad99d
bugfix
JZ-LIANG Jul 7, 2023
fc3dfe6
Merge remote-tracking branch 'upstream/develop' into semi-auto/embedd…
JZ-LIANG Jul 10, 2023
def09f0
Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule
JZ-LIANG Jul 10, 2023
934cc61
resolute input sharding conflict maybe
JZ-LIANG Jul 11, 2023
daf098a
Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule
JZ-LIANG Jul 11, 2023
99a10f4
fixed comment
JZ-LIANG Jul 12, 2023
49257b9
Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop…
JZ-LIANG Jul 12, 2023
fcf2ccb
add rule
JZ-LIANG Jul 12, 2023
4d2a854
Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop…
JZ-LIANG Jul 13, 2023
6f7199a
add unitest
JZ-LIANG Jul 13, 2023
7c69300
Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop…
JZ-LIANG Jul 17, 2023
013412f
fixed typoes
JZ-LIANG Jul 19, 2023
184 changes: 184 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc
@@ -0,0 +1,184 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
CrossEntropyWithSoftmaxSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: verify input args based on cross_entropy_with_softmax logic
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
2,
phi::errors::InvalidArgument("The size of InputSpec of cross entropy "
"with softmax should be 2, but got [%d].",
input_specs_size));

auto x_shape = input_specs[0].shape();
int x_ndim = x_shape.size();
auto x_dist_attr_src = input_specs[0].dist_attr();
std::vector<int64_t> x_dims_mapping_src = x_dist_attr_src.dims_mapping();

auto label_shape = input_specs[1].shape();
auto label_dist_attr_src = input_specs[1].dist_attr();
std::vector<int64_t> label_dims_mapping_src =
label_dist_attr_src.dims_mapping();

int axis = ExtractAttr<int>("axis", attrs);
int ignore_index = ExtractAttr<int>("ignore_index", attrs);
bool numeric_stable_mode = ExtractAttr<bool>("numeric_stable_mode", attrs);
bool use_softmax = ExtractAttr<bool>("use_softmax", attrs);
bool soft_label = ExtractAttr<bool>("soft_label", attrs);

VLOG(6) << "CrossEntropyWithSoftmaxSPMDRule InferForward Inputs: "
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping_src: ["
<< str_join(x_dims_mapping_src) << "]; Label shape: ["
<< str_join(label_shape) << "], Label dims mapping: ["
<< str_join(label_dims_mapping_src) << "]; axis: "
<< "[" << axis << "], ignore_index: [" << ignore_index
<< "], numeric_stable_mode: [" << numeric_stable_mode
<< "], use_softmax: [" << use_softmax << "], soft_label: ["
<< soft_label << "].";

// normalize axis
if (axis < 0) {
axis = x_ndim + axis;
}

  // Trying to shard the softmax_normalize axis of the input tensor, BUT the
  // c_softmax_with_cross_entropy kernel does not support:
  //   1. soft label
  //   2. axis != -1
  // Support for these two features may be added in the future.
if (x_dims_mapping_src[axis] > -1) {
PADDLE_ENFORCE_EQ(
soft_label,
false,
phi::errors::InvalidArgument(
"Trying to shard the softmax_normalize axis of the input tensor, "
"but the soft_label is set as True, which is not supported yet!"));

PADDLE_ENFORCE_EQ(
axis,
x_ndim - 1,
phi::errors::InvalidArgument(
"Trying to shard the softmax_normalize axis of the input tensor, "
"but the softmax_normalize axis is not the last axis, which is not "
"supported yet! The softmax_normalize is [%d].",
axis));

PADDLE_ENFORCE_EQ(use_softmax,
true,
phi::errors::InvalidArgument(
"Trying to shard the softmax_normalize axis of the "
"input tensor, use_softmax must be set to True !"));
}

// step1: build Einsum Notation
std::string alphabet =
"abcdefghijlmnopqrstuvwxyz"; // k for softmax_normalize axis
std::string broadcast_axes =
GetBroadcastAxes(x_ndim - 1, x_ndim - 1, alphabet);
std::string x_axes = broadcast_axes;
x_axes.insert(axis, "k");
std::string label_axes;
if (soft_label) {
label_axes = x_axes;
} else {
label_axes = broadcast_axes;
label_axes.insert(axis, "1");
}
std::string loss_axes = broadcast_axes;
loss_axes.insert(axis, "1");
// optional output
std::string softmax_out_axes;
if (use_softmax) {
softmax_out_axes = x_axes;
} else {
softmax_out_axes = "";
}

  // step2: Sharding Propagation
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info =
GetAxesDimsMappingPair({x_axes, label_axes}, input_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// step3: Infer dst Dims Mapping.
TensorDistAttr loss_dist_attr_dst =
CopyTensorDistAttrForOutput(label_dist_attr_src);
loss_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(loss_axes, axis_to_dim_map));
TensorDistAttr softmax_out_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
softmax_out_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map));

TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
TensorDistAttr label_dist_attr_dst =
CopyTensorDistAttrForOutput(label_dist_attr_src);
label_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(label_axes, axis_to_dim_map));

VLOG(4) << "CrossEntropyWithSoftmaxSPMDRule InferForward Inputs: "
<< "Einsum notation: [" << x_axes << "," << label_axes << " --> "
<< softmax_out_axes << "," << loss_axes << "]. " << std::endl
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping_src: ["
<< str_join(x_dims_mapping_src) << "], x_dims_mapping_dst: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]; Label shape: ["
<< str_join(label_shape) << "], label_dims_mapping_src: ["
<< str_join(label_dims_mapping_src) << "], label_dims_mapping_dst: ["
<< str_join(label_dist_attr_dst.dims_mapping())
<< "]; loss_dims_mapping: ["
<< str_join(loss_dist_attr_dst.dims_mapping())
<< "], softmax_out_dims_mapping_src: ["
<< str_join(softmax_out_dist_attr_dst.dims_mapping()) << "]; axis: "
<< "[" << axis << "], ignore_index: [" << ignore_index
<< "], numeric_stable_mode: ["
<< (numeric_stable_mode ? "true" : "false") << "], use_softmax: ["
<< (use_softmax ? "true" : "false") << "], soft_label: ["
<< (soft_label ? "true" : "false") << "].";

  // TODO: if the softmax_normalize axis is sharded, notify the downstream phi
  // API to select the c_softmax_with_cross_entropy kernel.

  // According to the phi API implementation, the softmax_out tensor will
  // always be generated no matter the value of use_softmax.
return {{x_dist_attr_dst, label_dist_attr_dst},
{softmax_out_dist_attr_dst, loss_dist_attr_dst}};
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
CrossEntropyWithSoftmaxSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of CrossEntropyWithSoftmaxSPMDRule is NOT implemented "
"yet."));
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
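
A minimal, self-contained Python sketch of the notation-building and sharding-merge steps above (illustrative only, not the Paddle implementation: build_notation, merge_sharding, and dims_for are hypothetical helpers, and the real ShardingMergeForTensors also resolves conflicting shardings, which is omitted here). It reproduces the hard-label, axis = -1 derivation for the GPT DP case exercised in the unit test below.

def build_notation(x_ndim, axis, soft_label):
    # 'k' marks the softmax_normalize axis; the remaining letters are
    # broadcast axes (the alphabet in the rule above deliberately skips 'k').
    broadcast = "abcdefghij"[: x_ndim - 1]
    x_axes = broadcast[:axis] + "k" + broadcast[axis:]
    label_axes = x_axes if soft_label else broadcast[:axis] + "1" + broadcast[axis:]
    loss_axes = broadcast[:axis] + "1" + broadcast[axis:]
    return x_axes, label_axes, loss_axes

def merge_sharding(notations, mappings):
    # For each axis letter, keep the first mesh dim that is not -1;
    # '1' marks a size-1 axis and never carries a sharding.
    merged = {}
    for axes, mapping in zip(notations, mappings):
        for ax, dim in zip(axes, mapping):
            if ax != "1" and merged.get(ax, -1) == -1:
                merged[ax] = dim
    return merged

def dims_for(axes, merged):
    return [-1 if ax == "1" else merged.get(ax, -1) for ax in axes]

# GPT DP case from the test below: x is [batch, seq, vocab], axis = 2 (vocab).
x_axes, label_axes, loss_axes = build_notation(x_ndim=3, axis=2, soft_label=False)
merged = merge_sharding([x_axes, label_axes], [[1, -1, -1], [-1, 0, -1]])
print(x_axes, label_axes, loss_axes)  # abk ab1 ab1
print(dims_for(x_axes, merged))       # [1, 0, -1]  (x and softmax_out)
print(dims_for(loss_axes, merged))    # [1, 0, -1]  (label and loss)
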
35 changes: 35 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h
@@ -0,0 +1,35 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

class CrossEntropyWithSoftmaxSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -15,6 +15,7 @@
#pragma once

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
@@ -145,6 +146,10 @@ REGISTER_SPMD_RULE(lookup_table_v2, EmbeddingSPMDRule);
REGISTER_SPMD_RULE(softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);

// cross_entropy_with_softmax
REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);
Comment on lines +150 to +151

Contributor: Do the two ops have the same attrs?

Contributor Author: Yes. This is to accommodate the difference between Paddle's static mode and dygraph mode; we will unify them in the future.


} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
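
Because both op names are registered to the same rule class, either name retrieves the same rule from Python. A quick check, using the get_spmd_rule helper that the test below also imports (which name corresponds to static vs. dygraph mode is not stated in the thread):

from paddle.distributed.auto_parallel.static.completion import get_spmd_rule

# Both registered names resolve to the CrossEntropyWithSoftmaxSPMDRule above;
# the duplicate registration bridges static-mode and dygraph op naming.
rule_a = get_spmd_rule("cross_entropy_with_softmax")
rule_b = get_spmd_rule("softmax_with_cross_entropy")
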
152 changes: 152 additions & 0 deletions test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py
@@ -0,0 +1,152 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto


class TestCrossEntropyWithSoftmaxSPMDRule(unittest.TestCase):
def setUp(self):
self.rule1 = get_spmd_rule("cross_entropy_with_softmax")

x_shape = [8, 1024, 50304] # [batch_size, max_seq_len, vocab_size]
label_shape = [8, 1024, 1]

process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
label_tensor_dist_attr = TensorDistAttr()
label_tensor_dist_attr.process_mesh = process_mesh
        self.label_dist_tensor_spec = DistTensorSpec(
            label_shape, label_tensor_dist_attr
        )

self.attrs = {
'ignore_index': -1,
'axis': -1,
'numeric_stable_mode': True,
'use_softmax': True,
'soft_label': False,
}

    def test_cross_entropy_with_softmax_infer_forward(self):
        # GPT DP case
        self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        self.label_dist_tensor_spec.set_dims_mapping([-1, 0, -1])

        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.label_dist_tensor_spec], self.attrs
        )
        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(inferred_input_dist_attrs), 2)
        self.assertEqual(len(inferred_output_dist_attrs), 2)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [1, 0, -1])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [1, 0, -1])

        self.assertEqual(
            inferred_output_dist_attrs[1].dims_mapping, [1, 0, -1]
        )  # loss
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [1, 0, -1]
        )  # softmax output

        # GPT MP case, shard normalized axis
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
        self.label_dist_tensor_spec.set_dims_mapping([-1, -1, -1])

        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.label_dist_tensor_spec], self.attrs
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [-1, -1, -1])

        self.assertEqual(
            inferred_output_dist_attrs[1].dims_mapping, [-1, -1, -1]
        )  # loss
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [-1, -1, 0]
        )  # softmax output

        # GPT MP-DP case
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
        self.label_dist_tensor_spec.set_dims_mapping([1, -1, -1])

        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.label_dist_tensor_spec], self.attrs
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [1, -1, 0])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [1, -1, -1])

        self.assertEqual(inferred_output_dist_attrs[1].dims_mapping, [1, -1, -1])
        self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [1, -1, 0])

        # Soft Label Error
        self.attrs['soft_label'] = True
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
        self.label_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        with self.assertRaises(ValueError):
            result_dist_attrs = self.rule1.infer_forward(
                [self.x_dist_tensor_spec, self.label_dist_tensor_spec],
                self.attrs,
            )
        self.attrs['soft_label'] = False

        # Normalized axis
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0])
        self.label_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
        result_dist_attrs = self.rule1.infer_forward(
            [self.x_dist_tensor_spec, self.label_dist_tensor_spec], self.attrs
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [1, -1, 0])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [1, -1, 0])

        self.assertEqual(inferred_output_dist_attrs[1].dims_mapping, [1, -1, 0])
        self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [1, -1, 0])
        self.attrs['axis'] = -1

        # Error: shard the normalized axis when axis != -1
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1])
        self.label_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        with self.assertRaises(ValueError):
            result_dist_attrs = self.rule1.infer_forward(
                [self.x_dist_tensor_spec, self.label_dist_tensor_spec],
                self.attrs,
            )
        self.attrs['axis'] = -1


if __name__ == "__main__":
unittest.main()
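
To read the expected dims_mapping values above: on the 2x4 process mesh from setUp, -1 means the tensor dimension is replicated, and any other value is the index of the mesh dimension the tensor dimension is split across. A small illustrative helper (local_shape is hypothetical, not a Paddle API) that derives the per-rank shard shape:

def local_shape(global_shape, dims_mapping, mesh_shape=(2, 4)):
    # dims_mapping[i] == -1: tensor dim i is replicated;
    # otherwise dim i is split evenly across mesh dim dims_mapping[i].
    return [
        size if mesh_dim == -1 else size // mesh_shape[mesh_dim]
        for size, mesh_dim in zip(global_shape, dims_mapping)
    ]

# GPT DP case: x of shape [8, 1024, 50304] with dims_mapping [1, 0, -1]
# is split 4-way over batch and 2-way over sequence, vocab replicated.
print(local_shape([8, 1024, 50304], [1, 0, -1]))  # [2, 512, 50304]
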