diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h index cc5057396265c..9c4ee034b7409 100644 --- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h +++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h @@ -51,7 +51,7 @@ struct LogsumexpFunctor { auto x_mt = (*x).template cast(); auto y_dim = y->dimensions(); - auto x_max = x_mt.maximum(dim); + auto x_max = x_mt.maximum(dim).eval(); y->device(place) = (x_max + (x_mt - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())