
Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into trt-IPluginV2Ext
zlsh80826 committed Jul 9, 2021
2 parents df49070 + 1f28968 commit 95f4d3d
Showing 14 changed files with 549 additions and 149 deletions.
3 changes: 3 additions & 0 deletions AUTHORS.md
@@ -78,3 +78,6 @@
| zhaopu7 | Pu Zhao |
| zhouxiao-coder | Xiao Zhou |
| Zrachel | Rui-Qing Zhang |
+ | jeng1220 | Bai-Cheng(Ryan) Jeng (NVIDIA) |
+ | mingxu1067 | Ming Huang (NVIDIA) |
+ | zlsh80826 | Reese Wang (NVIDIA) |
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/partial_grad_engine.cc
@@ -73,6 +73,7 @@ static void GetGraphInfoBetweenTargets(
std::unordered_map<OpBase *, size_t> *op_deps_ptr,
std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
VLOG(10) << "prune graph starts";
/**
* Step 1. Find the candidate startup grad ops, prepared for following BFS.
*/
@@ -117,6 +118,8 @@ static void GetGraphInfoBetweenTargets(
auto *op = op_node_pair.first;
auto *node = op_node_pair.second;

VLOG(10) << "Visit node " << node << " , visit op " << op->Type();

for (auto &output_pair : op->GetOutsMap()) {
if (!output_pair.second.IsGrad()) {
VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var";
@@ -135,6 +138,7 @@

for (auto &pending_node : node->GradPendingNodes()) {
if (visited.count(pending_node.get()) == 0) {
+ visited.insert(pending_node.get());
for (auto &pending_op : *pending_node) {
preceding_ops[&pending_op].insert(op);
q.emplace(&pending_op, pending_node.get());
@@ -143,6 +147,8 @@
}
}

VLOG(10) << "Found endpoint op ends";

/**
* Step 3. Based on the found input_target_grads, BFS the graph in reverse
* order. `target_vars` would record all grad vars in the graph, and
@@ -246,6 +252,8 @@ static void GetGraphInfoBetweenTargets(
}
}

VLOG(10) << "Found startup op ends";

/**
* Step 4. Prune output_targets which is not the input of startup_ops
*/
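The substantive pieces of this file's diff are the extra VLOG(10) trace points around the pruning steps and the added `visited.insert(pending_node.get());`, which marks a pending node as visited at the moment it is enqueued, so the BFS can never push the same node (and its ops) onto the queue more than once. A minimal sketch of that enqueue-time marking pattern follows; `Node` and its `neighbors` member are placeholders, not Paddle's GradOpNode machinery.

```cpp
#include <queue>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> neighbors;
};

// Breadth-first traversal that marks a node visited when it is enqueued,
// so the same node can never sit in the queue twice.
void Bfs(Node* start) {
  std::queue<Node*> q;
  std::unordered_set<Node*> visited;
  q.push(start);
  visited.insert(start);
  while (!q.empty()) {
    Node* cur = q.front();
    q.pop();
    for (Node* next : cur->neighbors) {
      if (visited.count(next) == 0) {
        visited.insert(next);  // mark before pushing, mirroring the added line above
        q.push(next);
      }
    }
  }
}
```

Marking only when a node is popped would let duplicates pile up in the queue on graphs with many converging edges, which is exactly what the added insert avoids.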
12 changes: 6 additions & 6 deletions paddle/fluid/operators/slice_op_npu.cc
@@ -25,15 +25,16 @@ namespace operators {

using Tensor = framework::Tensor;

- void UpdateAttr(const framework::DDim in_dims, const std::vector<int> axes,
+ void UpdateAttr(const framework::DDim& in_dims, const std::vector<int> axes,
const std::vector<int> starts, const std::vector<int> ends,
std::vector<int>* offsets, std::vector<int>* size) {
int cnt = 0;
for (int i = 0; i < in_dims.size(); ++i) {
int start = 0;
int end = in_dims[i];
- int axis = axes[cnt];

+ // NOTE(zhiqiu): Becareful that cnt may > axes.size() and result in
+ // overflow.
+ int axis = cnt < static_cast<int>(axes.size()) ? axes[cnt] : -1;
if (axis == i) {
start = starts[cnt];
if (start < 0) {
@@ -63,10 +64,10 @@ class SliceNPUKernel : public framework::OpKernel<T> {
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
auto ends = ctx.Attr<std::vector<int>>("ends");
const auto& in_dims = input->dims();

out->mutable_data<T>(ctx.GetPlace());

- auto in_dims = input->dims();
std::vector<int> offsets(in_dims.size());
std::vector<int> size(in_dims.size());

@@ -93,8 +94,7 @@ class SliceGradNPUKernel : public framework::OpKernel<T> {
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
auto ends = ctx.Attr<std::vector<int>>("ends");

- auto in_dims = input->dims();
+ const auto& in_dims = input->dims();
int rank = in_dims.size();

std::vector<int> offsets(rank);
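Two things change in this file: `UpdateAttr` now takes the `DDim` by const reference (avoiding a copy of the dims object), and the lookup `axes[cnt]` is guarded so that once every requested axis has been consumed, `cnt == axes.size()` no longer reads past the end of the vector. Below is a simplified, self-contained sketch of the guarded loop; the name `UpdateOffsets`, the plain `std::vector<int>` dims, and the offsets-only computation are stand-ins for the kernel's real signature, not the NPU implementation itself.

```cpp
#include <vector>

// Simplified illustration of the bounds guard added to UpdateAttr; plain
// std::vector<int> stands in for framework::DDim, and only offsets are computed.
void UpdateOffsets(const std::vector<int>& in_dims, const std::vector<int>& axes,
                   const std::vector<int>& starts, std::vector<int>* offsets) {
  offsets->assign(in_dims.size(), 0);  // unsliced dims start at offset 0
  int cnt = 0;
  for (int i = 0; i < static_cast<int>(in_dims.size()); ++i) {
    // Guard the lookup: after all entries of `axes` have been consumed,
    // cnt equals axes.size() and axes[cnt] would read past the end.
    int axis = cnt < static_cast<int>(axes.size()) ? axes[cnt] : -1;
    if (axis == i) {
      int start = starts[cnt];
      if (start < 0) start += in_dims[i];  // negative starts count from the end
      (*offsets)[i] = start;
      ++cnt;
    }
  }
}
```

With the guard in place, dimensions beyond the last sliced axis simply fall into the `axis != i` branch and keep an offset of zero.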

1 comment on commit 95f4d3d

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉
