Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] Scatter gather 0d support #48452

Merged
merged 2 commits on Dec 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 55 additions & 23 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x,
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index_dims.size()));
}

auto input_dim = x.dims();
auto axis_v = axis.to<int>();
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
if (index_dims.size() == 0) {
// 0D index will decrease the dimension
if (input_dim.size() == 1) {
// the index is a 0D tensor and the x is a 1D tensor
out->set_dims(phi::DDim(phi::Dim<0>()));
} else {
if (axis.FromTensor() || axis_v == 0) {
// decrease the output dimension
std::vector<int> out_dim_vec;
for (int i = 1; i < input_dim.size(); ++i) {
out_dim_vec.emplace_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
} else {
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}

Expand Down
49 changes: 26 additions & 23 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x,
"index is a 2D tensor, but we get %d.",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be a 0D or 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
}
if (index_dims.size() != 0) {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
phi::errors::InvalidArgument("The index should be a 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
(ref_dims.size() == updates_dims.size()),
true,
phi::errors::InvalidArgument(
"When the Input(Updates) is not a 0D tensor, the "
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
}
PADDLE_ENFORCE_EQ(
ref_dims.size(),
updates_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Updates) should have the same shape size, "
"but received the size of Input(x)'s shape is %d, the size of "
"Input(Updates)'s shape is %d.",
ref_dims.size(),
updates_dims.size()));
PADDLE_ENFORCE_EQ(
updates_dims[0],
index_dims[0],
phi::errors::InvalidArgument(
"Input(Updates) and Input(Ids) should have same batch-size, but"
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
"batch-size is %d.",
updates_dims[0],
index_dims[0]));
out->set_dims(ref_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
Expand Down
9 changes: 4 additions & 5 deletions paddle/phi/kernels/funcs/gather.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx,
}

// index size
int64_t index_size = index.dims()[0];
if (index_size == 0) return;
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
int64_t slice_size = 1;
Expand Down Expand Up @@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input,
inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
Expand Down
28 changes: 18 additions & 10 deletions paddle/phi/kernels/funcs/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx,
const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
Expand All @@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx,
"in gather_op, but received value is [%d].",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is [%d].",
index.dims().size()));
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"The index should be 0D or 1D, when it is not 2D, but we get %d",
index.dims().size()));
}
int64_t index_size = index.dims()[0];

int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

auto src_dims = src.dims();

Expand Down Expand Up @@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx,
inner_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
if (index->dims().size() != 0) {
out_dim_vec.push_back(index_size);
}
for (int i = axis_index + 1; i < input_dim.size(); i++) {
outer_dim_size *= input_dim[i];
out_dim_vec.push_back(input_dim[i]);
Expand Down Expand Up @@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,

if (input->numel() == 0) return;
int axis_index = axis;
int64_t input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size;
if (input_dim.size() == out->dims().size()) {
input_index_dim_size = input_dim[axis_index];
} else {
// 0d index
input_index_dim_size = 1;
}

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
Expand Down
26 changes: 16 additions & 10 deletions paddle/phi/kernels/funcs/scatter.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
const DenseTensor& index,
DenseTensor* output,
bool overwrite = true) {
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(
index.dims()[1],
Expand All @@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
"But received value is [%d]",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"index.dims().size() should be 0, 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
}
int64_t index_size = index.dims()[0];

int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];

auto src_dims = src.dims();
phi::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
size_t slice_size = 1;
if (index.dims().size() != 0) {
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
} else {
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
}

const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();

const size_t& slice_bytes = slice_size * sizeof(T);

// set block and grid num
Expand Down
Loading