Skip to content

Commit

Permalink
Merge with add_n_kernel PR PaddlePaddle#49854
Browse files Browse the repository at this point in the history
  • Loading branch information
zhhsplendid committed Jan 16, 2023
1 parent 9cfb275 commit 5dfd603
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 19 deletions.
4 changes: 1 addition & 3 deletions paddle/fluid/operators/controlflow/conditional_block_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,8 @@ class ConditionalBlockGradOp : public ConditionalOp {
if (!input_tensor.IsInitialized() || input_tensor.numel() == 0) {
return;
}
if (input_tensor.dims().size() != 0) {
outside_tensor->Resize(input_tensor.dims());
}
VLOG(4) << "Assigning zero to " << outside_tensor;
outside_tensor->Resize(input_tensor.dims());
outside_tensor->mutable_data(place, input_tensor.dtype());
const platform::DeviceContext *dev_ctx =
platform::DeviceContextPool::Instance().Get(place);
Expand Down
7 changes: 0 additions & 7 deletions paddle/phi/kernels/cpu/add_n_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ void AddNKernel(const Context& dev_ctx,
const std::vector<const TensorBase*>& x,
DenseTensor* out) {
size_t in_num = x.size();
for (const TensorBase* tb : x) {
if (tb->initialized() && DenseTensor::classof(tb)) {
auto* dt = static_cast<const DenseTensor*>(tb);
out->set_meta(dt->meta());
break;
}
}
dev_ctx.template Alloc<T>(out);

bool in_place = false;
Expand Down
9 changes: 0 additions & 9 deletions paddle/phi/kernels/gpu/add_n_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,6 @@ void AddNKernel(const Context &dev_ctx,
grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
blocks = dim3(tile_size, 1, 1);
};

for (const TensorBase *tb : x) {
if (tb->initialized() && DenseTensor::classof(tb)) {
auto *dt = static_cast<const DenseTensor *>(tb);
out->set_meta(dt->meta());
break;
}
}

auto *out_ptr = dev_ctx.template Alloc<T>(out);
bool in_place = false;
if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
Expand Down

0 comments on commit 5dfd603

Please sign in to comment.