[Phi]Add meshgrid yaml and unittest #41411

Merged: 1 commit, Apr 5, 2022

148 changes: 148 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -410,5 +410,153 @@ std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
  return x_grad;
}

std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs) {
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;

  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    auto kernel_key_set = ParseKernelKeyByInputArgs(inputs);
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {
      kernel_backend = kernel_key.backend();
    }
    if (kernel_layout == DataLayout::UNDEFINED) {
      kernel_layout = kernel_key.layout();
    }
    if (kernel_data_type == DataType::UNDEFINED) {
      kernel_data_type = kernel_key.dtype();
    }
  }

  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "meshgrid", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "meshgrid API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "meshgrid API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

  auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
  std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
  for (size_t i = 0; i < input_inputs.size(); ++i) {
    input_inputs[i] = &input_inputs_vec->at(i);
  }

  auto x_meta_vec = MakeMetaTensor(input_inputs);
  std::vector<phi::MetaTensor*> inputs_metas(x_meta_vec.size());
  for (size_t i = 0; i < x_meta_vec.size(); ++i) {
    inputs_metas[i] = &x_meta_vec[i];
  }

  // Calculate the number of out tensors
  size_t out_number = inputs.size();

  std::vector<Tensor> out;
  auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);

  std::vector<phi::MetaTensor> meta_outs;
  meta_outs.reserve(out_number);
  std::vector<phi::MetaTensor*> meta_out_ptrs;
  meta_out_ptrs.reserve(out_number);
  for (size_t i = 0; i < out_number; ++i) {
    meta_outs.push_back(dense_outs[i]);
    meta_out_ptrs.push_back(&meta_outs.back());
  }
  phi::MeshgridInferMeta(inputs_metas, meta_out_ptrs);

  using kernel_signature = void (*)(const platform::DeviceContext&,
                                    const std::vector<const phi::DenseTensor*>&,
                                    std::vector<phi::DenseTensor*>&);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)(*dev_ctx, input_inputs, dense_outs);

  return out;
}

std::vector<Tensor> meshgrid_grad_impl(
    const std::vector<Tensor>& inputs,
    const std::vector<Tensor>& outputs_grad) {
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;

  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    auto kernel_key_set = ParseKernelKeyByInputArgs(inputs, outputs_grad);
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {
      kernel_backend = kernel_key.backend();
    }
    if (kernel_layout == DataLayout::UNDEFINED) {
      kernel_layout = kernel_key.layout();
    }
    if (kernel_data_type == DataType::UNDEFINED) {
      kernel_data_type = kernel_key.dtype();
    }
  }

  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "meshgrid_grad", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "meshgrid_grad API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "meshgrid_grad API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

  auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
  std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
  for (size_t i = 0; i < input_inputs.size(); ++i) {
    input_inputs[i] = &input_inputs_vec->at(i);
  }
  auto input_outputs_grad_vec =
      PrepareData(outputs_grad, kernel.InputAt(1), {});
  std::vector<const phi::DenseTensor*> input_outputs_grad(
      input_outputs_grad_vec->size());
  for (size_t i = 0; i < input_outputs_grad.size(); ++i) {
    input_outputs_grad[i] = &input_outputs_grad_vec->at(i);
  }

  size_t out_number = inputs.size();
  std::vector<Tensor> api_output;
  auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);

  auto inputs_meta_vec = MakeMetaTensor(input_inputs);
  std::vector<phi::MetaTensor*> inputs_metas(inputs_meta_vec.size());
  for (size_t i = 0; i < inputs_meta_vec.size(); ++i) {
    inputs_metas[i] = &inputs_meta_vec[i];
  }

  auto outputs_grad_meta_vec = MakeMetaTensor(input_outputs_grad);
  std::vector<phi::MetaTensor*> outputs_grad_metas(
      outputs_grad_meta_vec.size());
  for (size_t i = 0; i < outputs_grad_meta_vec.size(); ++i) {
    outputs_grad_metas[i] = &outputs_grad_meta_vec[i];
  }

  std::vector<phi::MetaTensor> meta_outs;
  meta_outs.reserve(out_number);
  std::vector<phi::MetaTensor*> meta_out_ptrs;
  meta_out_ptrs.reserve(out_number);
  for (size_t i = 0; i < out_number; ++i) {
    meta_outs.push_back(kernel_out[i]);
    meta_out_ptrs.push_back(&meta_outs.back());
  }

  phi::MeshgridGradInferMeta(inputs_metas, outputs_grad_metas, meta_out_ptrs);

  using kernel_signature = void (*)(const platform::DeviceContext&,
                                    const std::vector<const phi::DenseTensor*>&,
                                    const std::vector<const phi::DenseTensor*>&,
                                    std::vector<phi::DenseTensor*>&);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)(*dev_ctx, input_inputs, input_outputs_grad, kernel_out);

  return api_output;
}

} // namespace experimental
} // namespace paddle
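
For context, the forward semantics that meshgrid_impl sets up (kernel selection, MeshgridInferMeta, then the variadic kernel call) can be summarized with a small NumPy sketch. This is only an illustration of the expected behavior, not code from this PR, and it assumes the 'ij'-style indexing that the unit tests below check: N 1-D inputs of lengths d1..dN produce N outputs, each of shape [d1, ..., dN].

# Reference sketch only (not part of the diff).
import numpy as np

def meshgrid_reference(*vectors):
    # Output i is input i placed along axis i and broadcast to the full grid.
    shape = [len(v) for v in vectors]
    outs = []
    for i, v in enumerate(vectors):
        view = [1] * len(vectors)
        view[i] = len(v)
        outs.append(np.broadcast_to(np.reshape(v, view), shape).copy())
    return outs

gx, gy = meshgrid_reference(np.arange(100), np.arange(200))
assert gx.shape == (100, 200) and gy.shape == (100, 200)
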
3 changes: 3 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
@@ -59,6 +59,9 @@ std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
                                    const Tensor& out_grad,
                                    int axis);
std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs);
std::vector<Tensor> meshgrid_grad_impl(const std::vector<Tensor>& inputs,
                                       const std::vector<Tensor>& outputs_grad);

} // namespace experimental
} // namespace paddle
14 changes: 14 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -245,6 +245,20 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
  dx->share_meta(x);
}

void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
                           const std::vector<MetaTensor*>& outputs_grad,
                           std::vector<MetaTensor*> inputs_grad) {
  PADDLE_ENFORCE_GT(outputs_grad.size(),
                    1,
                    errors::InvalidArgument(
                        "Number of Inputs(Out@Grad) should be larger than 1. "
                        "But received Inputs(Out@Grad)'s size = %d.",
                        outputs_grad.size()));
  for (size_t i = 0; i < inputs.size(); i++) {
    inputs_grad[i]->share_meta(*inputs[i]);
  }
}

void NllLossGradInferMeta(const MetaTensor& x,
                          const MetaTensor& label,
                          paddle::optional<const MetaTensor&> weight,
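
MeshgridGradInferMeta above only propagates each input's meta to the corresponding gradient tensor; the arithmetic lives in the meshgrid_grad kernel. As a rough reference (an assumption about the kernel's math, not code from this PR), the gradient of input i is the i-th output gradient summed over every axis except axis i, which restores the 1-D input shape:

# Reference sketch only (assumed reduction, not part of the diff).
import numpy as np

def meshgrid_grad_reference(inputs, outputs_grad):
    grads = []
    for i, grad in enumerate(outputs_grad):
        # Sum over all axes except the one owned by input i.
        axes = tuple(ax for ax in range(grad.ndim) if ax != i)
        grads.append(grad.sum(axis=axes))
    return grads

x, y = np.arange(3.0), np.arange(4.0)
douts = [np.ones((3, 4)), np.ones((3, 4))]  # grads of sum(out_0) and sum(out_1)
dx, dy = meshgrid_grad_reference([x, y], douts)
assert dx.shape == (3,) and dy.shape == (4,)
assert (dx == 4.0).all() and (dy == 3.0).all()
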
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -115,6 +115,10 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                   bool adaptive,
                                   MetaTensor* dx);

void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
                           const std::vector<MetaTensor*>& outputs_grad,
                           std::vector<MetaTensor*> inputs_grad);

void NllLossGradInferMeta(const MetaTensor& input,
                          const MetaTensor& label,
                          paddle::optional<const MetaTensor&> weight,
43 changes: 43 additions & 0 deletions python/paddle/fluid/tests/unittests/test_meshgrid_op.py
@@ -20,6 +20,7 @@
import paddle.fluid as fluid
import paddle
from paddle.fluid import compiler, Program, program_guard, core
from paddle.fluid.framework import _test_eager_guard


class TestMeshgridOp(OpTest):
@@ -149,6 +150,10 @@ def test_api_with_dygraph(self):
            assert np.array_equal(res_3.shape, [100, 200])
            assert np.array_equal(res_4.shape, [100, 200])

    def test_api_eager_dygraph(self):
        with _test_eager_guard():
            self.test_api_with_dygraph()


class TestMeshgridOp7(unittest.TestCase):
    def test_api_with_dygraph_list_input(self):
@@ -163,6 +168,10 @@ def test_api_with_dygraph_list_input(self):
            assert np.array_equal(res_3.shape, [100, 200])
            assert np.array_equal(res_4.shape, [100, 200])

    def test_api_eager_dygraph(self):
        with _test_eager_guard():
            self.test_api_with_dygraph_list_input()


class TestMeshgridOp8(unittest.TestCase):
    def test_api_with_dygraph_tuple_input(self):
@@ -177,6 +186,40 @@ def test_api_with_dygraph_tuple_input(self):
            assert np.array_equal(res_3.shape, [100, 200])
            assert np.array_equal(res_4.shape, [100, 200])

    def test_api_eager_dygraph(self):
        with _test_eager_guard():
            self.test_api_with_dygraph_tuple_input()


class TestMeshgridEager(unittest.TestCase):
    def test_dygraph_final_state_api(self):
        input_1 = np.random.randint(0, 100, [100, ]).astype('int32')
        input_2 = np.random.randint(0, 100, [200, ]).astype('int32')

        with fluid.dygraph.guard():
            tensor_1 = fluid.dygraph.to_variable(input_1)
            tensor_2 = fluid.dygraph.to_variable(input_2)
            tensor_1.stop_gradient = False
            tensor_2.stop_gradient = False
            res_1, res_2 = paddle.tensor.meshgrid((tensor_1, tensor_2))
            sum = paddle.add_n([res_1, res_2])
            sum.backward()
            with _test_eager_guard():
                tensor_eager_1 = fluid.dygraph.to_variable(input_1)
                tensor_eager_2 = fluid.dygraph.to_variable(input_2)
                tensor_eager_1.stop_gradient = False
                tensor_eager_2.stop_gradient = False
                res_eager_1, res_eager_2 = paddle.tensor.meshgrid(
                    (tensor_eager_1, tensor_eager_2))
                sum_eager = paddle.add_n([res_eager_1, res_eager_2])
                sum_eager.backward()
                self.assertEqual(
                    (tensor_1.grad.numpy() == tensor_eager_1.grad.numpy()).all(),
                    True)
                self.assertEqual(
                    (tensor_2.grad.numpy() == tensor_eager_2.grad.numpy()).all(),
                    True)


if __name__ == '__main__':
    paddle.enable_static()
4 changes: 3 additions & 1 deletion python/paddle/tensor/creation.py
@@ -776,10 +776,12 @@ def meshgrid(*args, **kwargs):

    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        args = args[0]
-   if paddle.in_dynamic_mode():
+   if _in_legacy_dygraph():
        num = len(args)
        out = _C_ops.meshgrid(list(args), num)
        return out
+   if in_dygraph_mode():
+       return _C_ops.final_state_meshgrid(list(args))

    name = kwargs.get("name", None)
    helper = LayerHelper('meshgrid', **locals())
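
With this change, eager dygraph mode now routes to the new final-state op instead of the legacy _C_ops.meshgrid path. A minimal way to exercise the new path from user code (assuming a build that includes this PR) is:

# Usage sketch (assumes a Paddle build with this PR).
import paddle

x = paddle.arange(3)
y = paddle.arange(4)
# In eager dygraph mode this dispatches to final_state_meshgrid, which is
# generated from the new api.yaml entry and invokes meshgrid_impl.
gx, gy = paddle.meshgrid(x, y)
print(gx.shape, gy.shape)  # [3, 4] [3, 4]
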
6 changes: 6 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
@@ -1120,6 +1120,12 @@
    func : mean
  backward : mean_grad

- api : meshgrid
  args : (Tensor[] inputs)
  output : Tensor[]
  invoke : meshgrid_impl(inputs)
  backward : meshgrid_grad

- api : min
  args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
  output : Tensor(out)
6 changes: 6 additions & 0 deletions python/paddle/utils/code_gen/backward.yaml
@@ -777,6 +777,12 @@
  kernel :
    func : mean_grad

- backward_api : meshgrid_grad
  forward : meshgrid (Tensor[] inputs) -> Tensor[](outputs)
  args : (Tensor[] inputs, Tensor[] outputs_grad)
  output : Tensor[](inputs_grad)
  invoke : meshgrid_grad_impl(inputs, outputs_grad)

- backward_api : min_grad
  forward: min (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out)
  args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false)
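
Because meshgrid_grad is wired through invoke rather than a kernel entry, the simplest end-to-end check is to differentiate through the op in eager mode. A small sketch (again assuming a build that includes this PR; the expected values follow the sum-reduction described earlier):

# Gradient check sketch (assumes a Paddle build with this PR).
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = paddle.to_tensor([4.0, 5.0], stop_gradient=False)
gx, gy = paddle.meshgrid(x, y)
(gx + gy).sum().backward()
# gx repeats x along axis 1 (size 2) and gy repeats y along axis 0 (size 3),
# so every entry of x.grad should be 2.0 and every entry of y.grad should be 3.0.
print(x.grad, y.grad)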