Skip to content

Commit

Permalink
fix performance problem caused by Conj (PaddlePaddle#38939)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg authored Jan 15, 2022
1 parent 050aa6f commit a887914
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion paddle/pten/kernels/complex_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
Expand All @@ -23,12 +24,29 @@ namespace pten {
template <typename T, typename Context>
void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);

template <typename T, typename Context>
// If T is complex
template <typename T,
typename Context,
std::enable_if_t<
std::is_same<T, paddle::platform::complex<float>>::value ||
std::is_same<T, paddle::platform::complex<double>>::value,
bool> = true>
DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
ConjKernel<T>(dev_ctx, x, &dense_out);
return dense_out;
}

// If T is not complex
template <typename T,
typename Context,
std::enable_if_t<
!std::is_same<T, paddle::platform::complex<float>>::value &&
!std::is_same<T, paddle::platform::complex<double>>::value,
bool> = true>
DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return x;
}

} // namespace pten

0 comments on commit a887914

Please sign in to comment.