infer reduce dims using out dims
cxxly committed Jan 16, 2023
1 parent 31ec399 commit bbe3480
Showing 5 changed files with 43 additions and 66 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/expand_v2_op.cc
@@ -204,7 +204,8 @@ class ExpandV2GradCompositeOpMaker : public prim::GradCompositeOpMakerBase {
auto x_grad_p = this->GetOutputPtr(&x_grad);
auto x_grad_name = this->GetOutputName(x_grad);
auto shape = this->Attr<std::vector<int>>("shape");
prim::expand_grad<prim::DescTensor>(x, out_grad, IntArray(shape), x_grad_p);
prim::expand_grad<prim::DescTensor>(
x, out_grad, paddle::experimental::IntArray(shape), x_grad_p);
this->RecoverOutputName(x_grad, x_grad_name);
}
};
10 changes: 5 additions & 5 deletions paddle/fluid/prim/api/manual/backward/composite_backward_api.h
Expand Up @@ -23,7 +23,6 @@ namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
// using IntArray = paddle::experimental::IntArray;
// This function should have the same signature as phi, which is defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
@@ -34,6 +33,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto grad_x_tmp = multiply<T>(grad_out, tmp);
grad_x->set_impl(grad_x_tmp.impl());
}

template <typename T>
void subtract_grad(const Tensor& x,
const Tensor& y,
@@ -148,9 +148,9 @@ void sum_grad(const Tensor& x,
axis_ = axis.GetData();
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, x_dim);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
x_grad_tmp = expand<T>(out_grad, x_dim);
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}

x_grad->set_impl(x_grad_tmp.impl());
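
Context for this hunk: when the reduced axes were dropped in the forward sum, the backward path above first unsqueezes out_grad at those axes and then expands it back to x's shape; the only change here is passing x_dim through an explicit IntArray. Below is a minimal, hypothetical standalone sketch of that shape flow, using plain std::vector rather than the prim unsqueeze/expand API.

```cpp
// Shape-only sketch (assumption: plain std::vector stands in for tensor shapes).
#include <cstdint>
#include <iostream>
#include <vector>

// Insert size-1 axes at the given (ascending) positions of the result shape.
std::vector<int64_t> Unsqueeze(std::vector<int64_t> shape,
                               const std::vector<int64_t>& axes) {
  for (int64_t a : axes) shape.insert(shape.begin() + a, 1);
  return shape;
}

int main() {
  // x: [2, 3, 4], forward sum over axis 1 drops the axis -> out_grad: [2, 4]
  std::vector<int64_t> out_grad_shape{2, 4};
  std::vector<int64_t> x_dim{2, 3, 4};
  auto unsqueezed = Unsqueeze(out_grad_shape, {1});  // [2, 1, 4]
  for (int64_t d : unsqueezed) std::cout << d << ' ';
  std::cout << "-> expand to [2, 3, 4]\n";  // expand<T>(out_grad_, IntArray(x_dim))
  return 0;
}
```
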
@@ -214,11 +214,11 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp);
x_grad->set_impl(x_grad_tmp.impl());
x_grad->set_impl(x_grad_tmp.impl());
}
}
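
The sqrt_grad body above is the usual square-root derivative rewritten in terms of the forward output (out = sqrt(x)), which is where the 0.5 constant comes from; a quick derivation for reference, not part of the change itself:

```latex
\frac{d}{dx}\sqrt{x} \;=\; \frac{1}{2\sqrt{x}} \;=\; \frac{0.5}{\mathrm{out}},
\qquad
\texttt{x\_grad} \;=\; \texttt{out\_grad}\cdot\frac{0.5}{\mathrm{out}}
```
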

template<typename T>
template <typename T>
void multiply_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
19 changes: 2 additions & 17 deletions paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
@@ -172,7 +172,7 @@ Tensor full<DescTensor>(paddle::experimental::IntArray shape,
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
  // TODO(jiabin, cxxly): This may have a runtime shape; skip InferShape for now.
op->InferShape(*block);
return out;
}

@@ -222,22 +222,7 @@ Tensor reshape<DescTensor>(Tensor x, paddle::experimental::IntArray shape) {
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}

template <>
Tensor expand<DescTensor>(const Tensor& x, const IntArray& shape) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("expand_v2");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetAttr("Shape", paddle::any_cast<std::vector<int>>(shape.GetData()));
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
  // TODO(jiabin, cxxly): This may have a runtime shape; skip InferShape for now.
return out;
}

67 changes: 31 additions & 36 deletions paddle/fluid/prim/api/manual/utils/utils.h
@@ -16,11 +16,12 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h"
using IntArray = paddle::experimental::IntArray;

namespace paddle {
namespace prim {
// We put some api like utils here
@@ -37,47 +38,41 @@ template <typename T>
void by_pass(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out);

// Returns reduced axes for param@shape, which broadcast with or broadcast to
// param@ref_shape.
// Note: Broadcast semantics is bidirectional. This method only returns reduced
// axes for direction shape to ref_shape.
static phi::DDim get_reduce_dims(const phi::DDim& shape,
const phi::DDim& ref_shape) {
// These methods don't need to be specified
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
const phi::DDim& in_dims) {
std::vector<int64_t> result;
auto src_shape = phi::vectorize(shape);
auto dst_shape = phi::vectorize(ref_shape);

// Align rank
if (src_shape.size() > dst_shape.size()) {
auto size = src_shape.size() - dst_shape.size();
for (std::size_t i = 0; i < size; i++) {
dst_shape.insert(std::begin(dst_shape), 1);
}
} else {
auto size = dst_shape.size() - src_shape.size();
for (std::size_t i = 0; i < size; i++) {
src_shape.insert(std::begin(src_shape), 1);
}
int bat = dout_dims.size() - in_dims.size();
for (int i = 0; i < bat; ++i) {
result.push_back(i);
}

// Reduced axes
for (std::size_t i = 0; i < src_shape.size(); i++) {
if (src_shape[i] == 1 && dst_shape[i] > 1) {
result.push_back(i);
} else if (src_shape[i] != dst_shape[i] && src_shape[i] != 1 &&
dst_shape[i] != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input arguments of GetReduceDims are not broadcastable. The "
"size of parameter shape[%d]:%d can not broadcast with the size "
"of parameter ref_shape[%d]:%d.",
i,
src_shape[i],
i,
dst_shape[i]));
for (int i = 0; i < in_dims.size(); ++i) {
if (in_dims[i] == 1) {
result.push_back(i + bat);
} else {
PADDLE_ENFORCE_EQ(
in_dims[i],
dout_dims[i + bat],
platform::errors::InvalidArgument(
"ReduceDims dimension mismatch. Operands could "
"not be broadcast together with the shape of dout = [%s] and "
"the shape of in_dims = [%s]. Received [%d] in X is not equal to "
"[%d] in Y at i:%d.",
dout_dims,
in_dims,
dout_dims[i + bat],
in_dims[i],
i));
}
}
return phi::make_ddim(result);
}

static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
const phi::DDim& y_dims) {
auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims);
return get_reduce_dims_from_out(out_dims, x_dims);
}

} // namespace prim
} // namespace paddle
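
To make the new helper concrete, here is a minimal standalone sketch of the axis inference in get_reduce_dims_from_out, under the assumption that plain std::vector can stand in for phi::DDim; unlike the real helper it omits the PADDLE_ENFORCE_EQ mismatch check.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed-equivalent logic: reduce the leading axes that exist only in dout,
// plus every axis where the input extent is 1 (i.e. was broadcast).
std::vector<int64_t> ReduceDimsFromOut(const std::vector<int64_t>& dout_dims,
                                       const std::vector<int64_t>& in_dims) {
  std::vector<int64_t> result;
  const int bat =
      static_cast<int>(dout_dims.size()) - static_cast<int>(in_dims.size());
  for (int i = 0; i < bat; ++i) {
    result.push_back(i);  // axes present only in dout
  }
  for (int i = 0; i < static_cast<int>(in_dims.size()); ++i) {
    if (in_dims[i] == 1) {
      result.push_back(i + bat);  // axes the input broadcast from extent 1
    }
  }
  return result;
}

int main() {
  // Shapes from the div-grad test below: x = [2, 3, 3, 4], y = [2, 3, 1, 4];
  // dy has to be summed over axis 2 before it can take y's shape.
  for (int64_t axis : ReduceDimsFromOut({2, 3, 3, 4}, {2, 3, 1, 4})) {
    std::cout << axis << ' ';  // prints: 2
  }
  std::cout << '\n';
  return 0;
}
```

The new get_reduce_dims wrapper then just broadcasts x_dims against y_dims with BroadcastTwoDims and delegates to this per-operand inference, which is the "infer reduce dims using out dims" of the commit title.
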
@@ -39,21 +39,17 @@
np.random.rand(2, 3, 1, 4),
np.float32,
),
(np.random.rand(2, 3, 3, 4), np.random.rand(2, 3, 1, 4), np.float32),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32
),
( np.random.rand(2, 1, 3, 4),
np.random.rand(2, 1, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 1, 1, 4),
np.float32,
)
),
],
)
class TestDivGradComp(unittest.TestCase):
