diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index 0a0d84d463ae..ff6130ca1b95 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -1226,7 +1226,6 @@ TEST_F(AtenXlaTensorTest, TestPixelShuffle) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::permute_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestSumToSize) { diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 1c91b29bc5b4..a7ae1c479640 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3748,10 +3748,8 @@ at::Tensor XLANativeFunctions::narrow_copy_symint(const at::Tensor& self, at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self, int64_t upscale_factor) { - XLA_CHECK( - !runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)); - return at::functionalization::functionalize_aten_op<ATEN_OP(pixel_shuffle)>::call(self, upscale_factor); + return bridge::AtenFromXlaTensor(tensor_methods::pixel_shuffle( + bridge::GetXlaTensor(self), upscale_factor)); } at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self, diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 7391f8ff7141..af4daf286486 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -593,6 +593,34 @@ torch::lazy::NodePtr Pdist_forward(const torch::lazy::Value& input, std::move(lower_fn), 1); } +torch::lazy::NodePtr PixelShuffle(const torch::lazy::Value& input, + int64_t upscale_factor) { + auto lower_fn = [=](const XlaNode& node, + LoweringContext* loctx) -> XlaOpVector { + xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); + return node.ReturnOp(BuildPixelShuffle(xla_input, upscale_factor), loctx); + }; + auto lower_for_shape_fn = + [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { + return BuildPixelShuffle(operands[0], upscale_factor); + }; + const xla::Shape& 
input_shape = GetXlaShape(input); + absl::Span<const int64_t> dimensions = input_shape.dimensions(); + int64_t channels = dimensions[1]; + + if (channels % (upscale_factor * upscale_factor) != 0) { + XLA_ERROR() << "Number of channels must be divisible by the square of the " + "upscale factor."; + } + + return GenericOp( + torch::lazy::OpKind(at::aten::pixel_shuffle), {input}, + [&]() { + return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn); + }, + std::move(lower_fn), 1); +} + torch::lazy::NodePtr LinalgVectorNorm(const torch::lazy::Value& input, const at::Scalar& ord, std::vector<int64_t> dimensions, diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 013474aa03c5..5d423b3b1eec 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -177,6 +177,9 @@ torch::lazy::NodePtr Pdist_forward(const torch::lazy::Value& input, const c10::optional<at::Scalar>& p, c10::optional<at::ScalarType> dtype); +torch::lazy::NodePtr PixelShuffle(const torch::lazy::Value& input, + int64_t upscale_factor); + torch::lazy::NodePtr LinalgVectorNorm(const torch::lazy::Value& input, const at::Scalar& ord, std::vector<int64_t> dimensions, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 092c76891830..083948b340ae 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -998,6 +998,12 @@ XLATensorPtr pdist_forward(const XLATensorPtr& input, double p) { return input->CreateFrom(Pdist_forward(input->GetIrValue(), p, dtype)); } +XLATensorPtr pixel_shuffle(const XLATensorPtr& input, int64_t upscale_factor) { + c10::optional<at::ScalarType> dtype = input->dtype_optional(); + torch::lazy::NodePtr node = PixelShuffle(input->GetIrValue(), upscale_factor); + return input->CreateFrom(node, dtype); +} + XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha) { return input->CreateFrom(Celu(input->GetIrValue(), alpha)); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 0a704dea6364..8503a9917f79 100644 --- 
a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -286,6 +286,8 @@ XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2, XLATensorPtr pdist_forward(const XLATensorPtr& input, double p); +XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor); + XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha); void celu_(XLATensorPtr& input, const at::Scalar& alpha); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 47b438a6e359..385cf8ab7258 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1194,6 +1194,27 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, } } +xla::XlaOp BuildPixelShuffle(xla::XlaOp input, int64_t upscale_factor) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + absl::Span<const int64_t> dimensions = input_shape.dimensions(); + int64_t batch_size = dimensions[0]; + int64_t channels = dimensions[1]; + int64_t height = dimensions[2]; + int64_t width = dimensions[3]; + + int64_t new_channels = channels / (upscale_factor * upscale_factor); + int64_t new_height = height * upscale_factor; + int64_t new_width = width * upscale_factor; + + xla::XlaOp tmp = + xla::Reshape(input, {batch_size, new_channels, upscale_factor, + upscale_factor, height, width}); + tmp = xla::Transpose(tmp, {0, 1, 4, 2, 5, 3}); + xla::XlaOp output = + xla::Reshape(tmp, {batch_size, new_channels, new_height, new_width}); + return output; +} + xla::XlaOp BuildMultinomial(xla::XlaOp input, int64_t num_samples, bool replacement, xla::XlaOp seed) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 8e632796c238..5382a84d165b 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -148,6 +148,8 @@ xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2, 
xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, bool use_hamming, bool use_chebyshev); +xla::XlaOp BuildPixelShuffle(xla::XlaOp input, int64_t upscale_factor); + xla::XlaOp BuildUpperTriangle(xla::XlaOp input); xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);