diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index ad8bb43de305..6e31511b0170 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -10,6 +10,12 @@ RMBUILD=1 LOGFILE=/tmp/pytorch_cpp_test.log XLA_EXPERIMENTAL="nonzero:masked_select" +# See Note [Keep Going] +CONTINUE_ON_ERROR=false +if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then + set +e +fi + if [ "$DEBUG" == "1" ]; then BUILDTYPE="Debug" fi diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index d0eb5c973c61..961556784d30 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -917,7 +917,7 @@ TEST_F(AtenXlaTensorTest, TestSVD) { } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::svd", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_linalg_svd", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestLinalgSVD) { @@ -1024,7 +1024,7 @@ TEST_F(AtenXlaTensorTest, TestSLogDet) { } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slogdet", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_linalg_slogdet", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestTriangularSolve) { @@ -2333,7 +2333,8 @@ TEST_F(AtenXlaTensorTest, TestKlDiv) { AllClose(output, xla_output); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::kl_div", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::sub", cpp_test::GetIgnoredCounters()); } } } @@ -3447,7 +3448,7 @@ TEST_F(AtenXlaTensorTest, TestBartlettWindow) { ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::arange_out", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } @@ -3579,7 +3580,7 @@ TEST_F(AtenXlaTensorTest, TestSiLUBackward) { device, testfn); }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::silu_backward", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::sigmoid", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestSigmoid) { @@ -3740,7 +3741,7 @@ TEST_F(AtenXlaTensorTest, TestMvOut) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::mv_out", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::mv", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestBatchAddBatchMatMul) { @@ -3831,7 +3832,7 @@ TEST_F(AtenXlaTensorTest, TestPinverse) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::svd", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_linalg_svd", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumOuter) { @@ -4166,7 +4167,6 @@ TEST_F(AtenXlaTensorTest, TestBilinear) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_trilinear", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestUpsampleNearest2D) { @@ -5025,6 +5025,7 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceProd) { if (UsingTpu()) { GTEST_SKIP(); } + torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); 
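An illustrative sketch (not from this patch): many counter expectations above move from view ops such as xla::slice and xla::view to their xla::slice_copy / xla::view_copy variants because, with functionalization enabled, view-producing ops on XLA tensors are lowered through the *_copy kernels. Assuming a working torch_xla build, the same effect can be observed from Python with the metrics API; the counter name mirrors the expectation used in TestReshape above.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

met.clear_counters()
t = torch.rand(4, 4, device=xm.xla_device())
v = t.reshape(2, 8)  # with functionalization enabled, this routes to xla::view_copy
xm.mark_step()
# counter_value() returns None if the counter was never incremented.
print(met.counter_value("xla::view_copy"))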
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong)); @@ -5053,6 +5054,7 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceProdInPlace) { if (UsingTpu()) { GTEST_SKIP(); } + torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong)); @@ -5080,6 +5082,7 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMin) { if (UsingTpu()) { GTEST_SKIP(); } + torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong)); @@ -5108,6 +5111,7 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMinInPlace) { if (UsingTpu()) { GTEST_SKIP(); } + torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat)); torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong)); @@ -5239,8 +5243,9 @@ TEST_F(AtenXlaTensorTest, TestInverse) { torch::Tensor xla_b = torch::inverse(xla_a); AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-4); }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::inverse", cpp_test::GetIgnoredCounters()); + ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*", + cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::linalg_inv_ex", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestIsnan) { @@ -5288,7 +5293,7 @@ TEST_F(AtenXlaTensorTest, TestExpandAs) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::expand", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::expand_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestExpandSymIntStatic) { @@ -5303,7 +5308,8 @@ TEST_F(AtenXlaTensorTest, TestExpandSymIntStatic) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::expand_symint", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::expand_copy_symint", + cpp_test::GetIgnoredCounters()); } static c10::SymInt make_symint(const torch::lazy::NodePtr& p) { @@ -5369,7 +5375,7 @@ TEST_F(AtenXlaTensorTest, TestBroadcastTensors) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::expand", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::expand_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestOneIndex) { @@ -5472,8 +5478,7 @@ TEST_F(AtenXlaTensorTest, TestMaskedScatter) { // calls. 
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); } - ExpectCounterChanged("xla::masked_scatter_", - cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::masked_scatter", cpp_test::GetIgnoredCounters()); ResetCounters(); }); } @@ -6423,7 +6428,7 @@ TEST_F(AtenXlaTensorTest, TestPrelu) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::prelu", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_prelu_kernel", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestHardshrink) { @@ -7072,7 +7077,7 @@ TEST_F(AtenXlaTensorTest, TestSelu) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::selu", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::elu", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestSeluInPlace) { @@ -7087,7 +7092,7 @@ TEST_F(AtenXlaTensorTest, TestSeluInPlace) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::selu_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::elu", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestCelu) { @@ -7118,7 +7123,7 @@ TEST_F(AtenXlaTensorTest, TestCeluInPlace) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::celu_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::celu", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestGelu) { @@ -7231,7 +7236,7 @@ TEST_F(AtenXlaTensorTest, TestReshape) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::view_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestResize) { @@ -7382,7 +7387,7 @@ TEST_F(AtenXlaTensorTest, TestNarrow) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } } @@ -7406,7 +7411,7 @@ TEST_F(AtenXlaTensorTest, TestNarrowUpdate) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } } @@ -7430,7 +7435,7 @@ TEST_F(AtenXlaTensorTest, TestNarrowUpdateBaseCheck) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } } @@ -7463,7 +7468,7 @@ TEST_F(AtenXlaTensorTest, TestNarrowUpdateTwoSlices) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } } @@ -7490,7 +7495,7 @@ TEST_F(AtenXlaTensorTest, TestNarrowUpdateView) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } } @@ -7517,7 +7522,7 @@ TEST_F(AtenXlaTensorTest, TestNarrowInNarrowUpdate) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::slice", cpp_test::GetIgnoredCounters()); + 
ExpectCounterChanged("xla::slice_copy", cpp_test::GetIgnoredCounters()); } } } @@ -7553,7 +7558,7 @@ TEST_F(AtenXlaTensorTest, TestViewAs) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::view_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestLogSoftmax) { @@ -8912,7 +8917,7 @@ TEST_F(AtenXlaTensorTest, TestMaskedFill) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::masked_fill_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestMaskedFillInPlace) { @@ -8931,10 +8936,10 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillInPlace) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::masked_fill_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast) { +TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast1) { torch::Tensor input = torch::rand({2, 5, 4, 3}, torch::TensorOptions(torch::kFloat)); torch::Tensor mask = @@ -8949,7 +8954,25 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::masked_fill_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast2) { + torch::Tensor input = + torch::rand({2, 1}, torch::TensorOptions(torch::kFloat)); + torch::Tensor mask = + torch::randint(0, 2, {2, 3}, torch::TensorOptions(torch::kBool)); + torch::Scalar value(42); + torch::Tensor result = torch::masked_fill(input, mask, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor xla_mask = CopyToDevice(mask, device); + torch::Tensor xla_result = torch::masked_fill(xla_input, xla_mask, value); + AllClose(result, xla_result); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestFill) { @@ -9067,7 +9090,7 @@ TEST_F(AtenXlaTensorTest, TestPixelShuffle) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::permute", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::permute_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestSumToSize) { @@ -9457,7 +9480,11 @@ TEST_F(AtenXlaTensorTest, TestDiagFlat) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::diag", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::zero_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::view_copy_symint", + cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_to_copy", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters()); } } @@ -9990,7 +10017,7 @@ TEST_F(AtenXlaTensorTest, TestMeshgrid) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::view_copy", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestConstantPad) { @@ -11407,44 
+11434,6 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) { } } -TEST_F(AtenXlaTensorTest, TestAmpForeachNonFiniteCheckAndUnscale) { - XlaDeviceType hw_type = - static_cast(GetDefaultDevice()->type()); - if (hw_type != XlaDeviceType::GPU && hw_type != XlaDeviceType::CPU) { - return; - } - torch::Tensor grads0 = - torch::tensor({1, 2, 3, 4}, torch::TensorOptions(torch::kFloat)); - torch::Tensor grads1 = torch::tensor({1.0, 2.0, std::nan("1"), 4.0}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor inv_scale = - torch::scalar_tensor(0.2, torch::TensorOptions(torch::kFloat)); - torch::Tensor found_inf = - torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat)); - torch::Tensor grads_output0 = grads0 * inv_scale; - torch::Tensor found_inf_output0 = - torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat)); - torch::Tensor found_inf_output1 = - torch::scalar_tensor(1, torch::TensorOptions(torch::kFloat)); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_grads0 = CopyToDevice(grads0, device); - torch::Tensor xla_inv_scale = CopyToDevice(inv_scale, device); - torch::Tensor xla_found_inf = CopyToDevice(found_inf, device); - torch::_amp_foreach_non_finite_check_and_unscale_(xla_grads0, xla_found_inf, - xla_inv_scale); - AllClose(grads_output0, xla_grads0, /*rtol=*/1e-2, /*atol=*/1e-4); - AllEqual(found_inf_output0, xla_found_inf); - - torch::Tensor xla_grads1 = CopyToDevice(grads1, device); - torch::_amp_foreach_non_finite_check_and_unscale_(xla_grads1, xla_found_inf, - xla_inv_scale); - AllEqual(found_inf_output1, xla_found_inf); - }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_amp_foreach_non_finite_check_and_unscale_", - cpp_test::GetIgnoredCounters()); -} - TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) { XlaDeviceType hw_type = static_cast(GetDefaultDevice()->type()); @@ -11754,46 +11743,6 @@ TEST_F(AtenXlaTensorTest, TestNanToNum) { ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestNanToNumInplace) { - for (torch::ScalarType scalar_type : - {torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor input = - isFloatingType(scalar_type) - ? torch::tensor( - {1.0, std::nan("1"), std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}, - torch::TensorOptions(scalar_type)) - : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type)); - torch::Tensor input_copy = input.clone(); - input.nan_to_num_(); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input_copy, device); - xla_input.nan_to_num_(); - if (static_cast( - bridge::AtenDeviceToXlaDevice(device).type()) == - XlaDeviceType::TPU && - scalar_type == torch::kDouble) { - // Since TPU converts double to float (unlike CPU), the Inf entries are - // expected to be different. Skipping checks for Inf entries. 
- AllEqual(input[0], xla_input[0]); - AllEqual(input[1], xla_input[1]); - } else { - AllClose(input, xla_input); - } - }); - input = input_copy.clone(); - input.nan_to_num_(/*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input_copy, device); - xla_input.nan_to_num_(/*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0); - AllClose(input, xla_input); - }); - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters()); -} - TEST_F(AtenXlaTensorTest, TestNanToNumOut) { for (torch::ScalarType scalar_type : {torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt, diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 7e5bdd92966c..86b1b7b55387 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -122,6 +122,7 @@ def test_simple_model(self): assert torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()) assert torch.allclose(input.grad, xla_input.grad.cpu()) + @unittest.skip("Broke by functionalization, #4680") def test_resnet18(self): torch._dynamo.reset() met.clear_counters() @@ -216,6 +217,7 @@ def test_simple_model(self): assert torch.allclose(input.grad, xla_input.grad.cpu()) assert torch.allclose(input, xla_input.cpu()) + @unittest.skip("Broke by functionalization, #4680") def test_resnet18(self): torch._dynamo.reset() met.clear_counters() diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 7deb443752f7..15946ceae6f7 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -201,6 +201,9 @@ 'test_random_to_xla', # doesn't raise 'test_copy_', # test against complex32 which is nto supported 'test_assertRaisesRegex_ignore_msg_non_native_device_xla', # segfault on wheel sanity test + 'test_index_reduce', # Broke by functionalization, pytorch/pytorch#94471 + 'test_logcumsumexp_xla', # doesn't raise, pytorch/pytorch#92912 + 'test_narrow_copy_non_contiguous', # the test is added for CPU, pytorch/pytorch#91789 }, # test_view_ops.py @@ -228,6 +231,8 @@ 'test_empty_ndim_index', # expecting a different runtime error 'test_index_put_byte_indices_xla', # expecting a different runtime error }, + + # test_indexing.py 'NumpyTestsXLA': { 'test_trivial_fancy_out_of_bounds', # expecting a different runtime error 'test_boolean_assignment_value_mismatch', # expecting a different runtime error @@ -277,6 +282,9 @@ 'test_upsamplingBicubic2d_correctness_xla', # FIXME! 
Got dtypes torch.float32 and torch.float64 'test_CTCLoss_no_batch_dim_xla', # Value out of range 'test_upsamplingBilinear2d_xla', # precision on GPU/TPU, slow compilation on CPU + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0 + 'test_GRU_grad_and_gradgrad_xla_float64', # Broke by functionalization, #4711 + 'test_LSTM_grad_and_gradgrad_xla_float64', # Broke by functionalization, #4711 }, # test/nn/test_dropout.py diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 0cae2e8db28f..7b292ba14173 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -48,8 +48,7 @@ def test_model_weight_metrics(self): model = nn.Linear(128, 64).to(xm.xla_device()) xs.mark_sharding(model.weight, self._get_mesh((1, self.n_devices)), partition_spec) - self.assertIn("VirtualDeviceUsage", met.counter_names()) - self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0) + self.assertNotIn("VirtualDeviceUsage", met.counter_names()) def test_no_sharding(self): t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 068954900934..82bbc88a2ea1 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -56,6 +56,10 @@ def test_simple_expand_on_2d_tensor(self): self.assertEqual(t4.shape[0], 2) self.assertEqual(t4.shape[1], size2) + # size_clone should be called as part of decomposition from + # the python dispatcher. + self.assertGreater(met.counter_value("xla::size_clone"), 0) + def test_simple_expand_add_dimension(self): size1 = 5 size2 = 2 @@ -178,6 +182,57 @@ def test_expand_symint_correctness(self): self.assertEqual(t3.shape[0], 2) self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu()) + def test_sizeGe(self): + met.clear_all() + + size1 = 5 + size2 = 2 + t1 = torch.zeros([size1, size2], device=dev) + t1[3][0] = 1 + # t2 has size [<=10, 2] + t2 = torch.nonzero(t1) + # Create a SizeAdd IR node. + # t2.shape[1] generates a SizeConstant node. + dyn_size = t2.shape[0] >= t2.shape[1] + self.assertGreater(met.counter_value("xla::size_ge"), 0) + # Exercises SizeGe::getDynamicValue. + dynamic_size = int(dyn_size) + self.assertEqual(dynamic_size, 0) + + def test_sizeLt(self): + met.clear_all() + + size1 = 5 + size2 = 2 + t1 = torch.zeros([size1, size2], device=dev) + t1[3][0] = 1 + # t2 has size [<=10, 2] + t2 = torch.nonzero(t1) + # Create a SizeAdd IR node. + # t2.shape[1] generates a SizeConstant node. + dyn_size = t2.shape[0] < t2.shape[1] + self.assertGreater(met.counter_value("xla::size_lt"), 0) + # Exercises SizeLt::getDynamicValue. + dynamic_size = int(dyn_size) + self.assertEqual(dynamic_size, 1) + + def test_sizeNe(self): + met.clear_all() + + size1 = 5 + size2 = 2 + t1 = torch.zeros([size1, size2], device=dev) + t1[3][0] = 1 + # t2 has size [<=10, 2] + t2 = torch.nonzero(t1) + # Create a SizeAdd IR node. + # t2.shape[1] generates a SizeConstant node. + dyn_size = t2.shape[0] != t2.shape[1] + self.assertGreater(met.counter_value("xla::size_ne"), 0) + # Exercises SizeNe::getDynamicValue. 
+ dynamic_size = int(dyn_size) + self.assertEqual(dynamic_size, 1) + if __name__ == '__main__': assert os.environ['XLA_EXPERIMENTAL'] != '' diff --git a/test/test_operations.py b/test/test_operations.py index d60e946e5484..0a48a532b6bc 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -385,9 +385,6 @@ def test_masked_select_shape(self): torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) - @unittest.skip( - "Temporarily disable test. See https://github.com/pytorch/xla/issues/4501" - ) def test_nonzero_cast(self): t1 = torch.ones(5, 2, device=xm.xla_device()) # Result of the nonzero should be the index type. Currently @@ -457,6 +454,55 @@ def test_cat_empty_tensor(self): x_cat = torch.cat([x, empty_tensor_xla], 0) self.assertEqual(t_cat.data, x_cat.data.cpu()) + def test_nan_to_num_in_place(self): + t = torch.tensor([float('nan'), float('nan'), -float('nan'), 3.14]) + + def fn(x): + x.nan_to_num_(1.0, 2.0, 3.0) + return x + + self.runAtenTest(t, fn) + + @skipOnTpu + def test_nan_to_num_in_place_with_inf(self): + # Since TPU converts double to float (unlike CPU), the Inf entries are + # expected to be different. Skipping tests for Inf entries. + t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + + def fn(x): + x.nan_to_num_(1.0, 2.0, 3.0) + return x + + self.runAtenTest(t, fn) + + @skipOnTpu + def test_amp_foreach_non_finite_check_and_unscale_(self): + # Since TPU converts double to float (unlike CPU), the Inf entries are + # expected to be different. Skipping tests for Inf entries. + grads0 = torch.tensor([1, 2, 3, 4], dtype=torch.float32) + grads1 = torch.tensor([1.0, 2.0, float('nan'), 4.0], dtype=torch.float32) + inv_scale = torch.tensor(0.2, dtype=torch.float32) + found_inf = torch.tensor(0, dtype=torch.float32) + grads_output0 = grads0 * inv_scale + found_inf_output0 = torch.tensor(0, dtype=torch.float32) + found_inf_output1 = torch.tensor(1, dtype=torch.float32) + + xla_device = xm.xla_device() + xla_grads0 = grads0.to(xla_device) + xla_inv_scale = inv_scale.to(xla_device) + xla_found_inf = found_inf.to(xla_device) + torch._amp_foreach_non_finite_check_and_unscale_([xla_grads0], + xla_found_inf, + xla_inv_scale) + self.assertEqual(grads_output0, xla_grads0, prec=1e-4) + self.assertEqual(found_inf_output0, xla_found_inf) + + xla_grads1 = grads1.to(xla_device) + torch._amp_foreach_non_finite_check_and_unscale_([xla_grads1], + xla_found_inf, + xla_inv_scale) + self.assertEqual(found_inf_output1, xla_found_inf) + def test_masked_fill_with_tensor(self): input = _gen_tensor(2, 5, 4, 3) mask = _gen_mask(input.size()) @@ -609,6 +655,9 @@ def test_empty_advanced_indexing(self): xla_result = xla_base[:, torch.empty(0, 6, dtype=torch.int64)] self.assertEqual(result, xla_result) + @unittest.skip( + "grad_input produces wrong results after functionalization. 
pytorch/pytorch#91199" + ) def test_empty_strided(self): xla_device = xm.xla_device() m = nn.Conv1d(4, 6, kernel_size=3, groups=2) @@ -632,6 +681,8 @@ def test_empty_strided(self): xla_output.sum() + sum(map(lambda x: x.sum(), xla_grad_input)), (xla_a, xla_output) + tuple(xla_m.parameters()), retain_graph=True) + self.assertEqual(output, xla_output, prec=1e-4) + self.assertEqual(grad_input, xla_grad_input, prec=1e-4) self.assertEqual(grad_grad_input, xla_grad_grad_input, prec=1e-4) def test_clamp(self): @@ -903,14 +954,6 @@ def test_inplace_view_non_contig(self): x.sum().backward() self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]]) - def test_view_data_update(self): - a = torch.zeros(4, device=xm.xla_device()) - v = a.view(2, 2) - a.data = a.data + 1 - self.assertEqual(a.tolist(), [1, 1, 1, 1]) - # Upadting a.data should not update v's value. - self.assertEqual(v.tolist(), [[0.0, 0.0], [0.0, 0.0]]) - def test_view_out_computation(self): def func(a, b): @@ -928,25 +971,32 @@ def test_set(self): t1 = torch.zeros(50, device=xm.xla_device()) t1 += 1 xm.mark_step() - self.assertEqual(met.counter_value('DestroyXlaTensor'), 2) - - t1.data = torch.zeros(20, device=xm.xla_device()) self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t1.set_(torch.zeros(10, device=xm.xla_device())) + t2 = torch.zeros(10, device=xm.xla_device()) self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) - t2 = torch.zeros(10, device=xm.xla_device()) t1.set_(t2) + self.assertEqual(met.counter_value('DestroyXlaTensor'), 6) + # shouldn't crash - t2.cpu() + self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) + + def test_replace_xla_tensor(self): + met.clear_all() - def test_view_data_slice(self): t1 = torch.zeros(50, device=xm.xla_device()) - t1_slice = t1.data[:5] - # Assigning the view back to origonal tensor's data should be OK. - t1.data = t1_slice - self.assertEqual(t1.tolist(), [0, 0, 0, 0, 0]) + t1 += 1 + xm.mark_step() + self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) + + t2 = torch.zeros(10, device=xm.xla_device()) + self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) + torch_xla._XLAC._replace_xla_tensor(t1, t2) + self.assertEqual(met.counter_value('DestroyXlaTensor'), 5) + + # shouldn't crash + self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) def test_pred_type(self): xla_device = xm.xla_device() @@ -1583,6 +1633,18 @@ def test_fn(*indices): for dtype in (torch.long, torch.int32, torch.bool) ], test_fn) + def test_conv2d_backward(self): + # Somehow eager cpu produces different results than us, and + # therefore we can't compare eager and xla. + conv = nn.Conv2d(1, 1, kernel_size=1).to('xla') + input = torch.tensor([[[[2077.0]]]]).to('xla') + + output = conv(input) + loss = torch.sum(output) + loss.backward() + self.assertTrue( + torch.allclose(conv.weight.grad.cpu(), torch.tensor([[[[2077.0]]]]))) + class MNISTComparator(nn.Module): diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 000000000000..37017a9aaedf --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#94537 diff --git a/torch_xla/csrc/aten_cpu_fallback.cpp b/torch_xla/csrc/aten_cpu_fallback.cpp index 6bd8bbaf219a..ad66874f6ba9 100644 --- a/torch_xla/csrc/aten_cpu_fallback.cpp +++ b/torch_xla/csrc/aten_cpu_fallback.cpp @@ -42,7 +42,9 @@ void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { } // Call the actual boxed CPU fallback. 
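Context for the change just below: xla_cpu_fallback is the boxed fallback used for ops that have no XLA lowering, and since functionalization now handles all view ops inside torch_xla, the fallback is told to error out if it ever sees a view op (the error_on_views flag mentioned in the new comment). An illustrative Python sketch (not from this patch), assuming a working torch_xla build and assuming, for illustration only, that linalg_eig has no XLA lowering: a fallback op shows up under an aten:: counter rather than an xla:: one.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

met.clear_counters()
t = torch.rand(3, 3, device=xm.xla_device())
# Assumed unlowered op: it takes the boxed CPU fallback, so it is reported
# under an aten:: counter instead of an xla:: one.
vals, vecs = torch.linalg.eig(t)
print(met.counter_value("aten::linalg_eig"))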
- at::native::cpu_fallback(op, stack); + // Set error_on_views as XLA should take care + // of all view ops after functionalization. + at::native::cpu_fallback(op, stack, true); } TORCH_LIBRARY_IMPL(_, XLA, m) { diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 54d4f8100274..8f0fa3350412 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -1,5 +1,8 @@ #include "torch_xla/csrc/aten_xla_bridge.h" +#include +#include + #include #include #include @@ -52,7 +55,8 @@ AtenXlaDeviceMapper* AtenXlaDeviceMapper::Get() { } XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) { - return dynamic_cast(tensor.unsafeGetTensorImpl()); + auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); + return dynamic_cast(inner_tensor.unsafeGetTensorImpl()); } } // namespace @@ -77,10 +81,11 @@ XLATensorPtr GetXlaTensor(const at::Tensor& tensor) { } void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) { + auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); XLATensorImpl* impl = - dynamic_cast(tensor.unsafeGetTensorImpl()); + dynamic_cast(inner_tensor.unsafeGetTensorImpl()); XLA_CHECK(impl != nullptr) - << "Input tensor is not an XLA tensor: " << tensor.toString(); + << "Input tensor is not an XLA tensor: " << inner_tensor.toString(); impl->set_tensor(std::move(new_xla_tensor)); } @@ -108,8 +113,12 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor, if (!tensor.defined()) { return XLATensorPtr(); } + auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); + if (!inner_tensor.defined()) { + return XLATensorPtr(); + } auto xtensor = TryGetXlaTensor(tensor); - return xtensor ? xtensor : XLATensor::Create(tensor, device); + return xtensor ? xtensor : XLATensor::Create(inner_tensor, device); } XLATensorPtr GetOrCreateXlaTensor(const c10::optional& tensor, @@ -118,7 +127,8 @@ XLATensorPtr GetOrCreateXlaTensor(const c10::optional& tensor, return XLATensorPtr(); } auto xtensor = TryGetXlaTensor(*tensor); - return xtensor ? xtensor : XLATensor::Create(*tensor, device); + auto inner_tensor = torch::lazy::maybe_unwrap_functional(*tensor); + return xtensor ? xtensor : XLATensor::Create(inner_tensor, device); } std::vector GetOrCreateXlaTensors( @@ -139,14 +149,16 @@ std::vector XlaCreateTensorList(const at::ITensorListRef& tensors) { std::vector to_translate(tensors.size()); size_t ix = 0; for (const auto& tensor : tensors) { - if (tensor.defined()) { - auto xtensor = TryGetXlaTensor(tensor); - if (xtensor) { - to_translate[ix] = true; - xla_tensors.push_back(xtensor); - } else { - aten_xla_tensors[ix] = tensor; - } + if (!tensor.defined()) continue; + auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); + if (!inner_tensor.defined()) continue; + + auto xtensor = TryGetXlaTensor(tensor); + if (xtensor) { + to_translate[ix] = true; + xla_tensors.push_back(xtensor); + } else { + aten_xla_tensors[ix] = tensor; } ++ix; } @@ -156,7 +168,12 @@ std::vector XlaCreateTensorList(const at::ITensorListRef& tensors) { // positions. for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) { if (to_translate[i]) { - aten_xla_tensors[i] = std::move(defined_aten_xla_tensors[defined_pos++]); + auto tensor = defined_aten_xla_tensors[defined_pos++]; + XLA_CHECK(!at::functionalization::impl::isFunctionalTensor(tensor)) + << "Expected non-functional tensor!"; + // This function is responsible for returning CPU tensors. 
+ // So we do not want to wrap the outputs into FunctionalTensorWrappers. + aten_xla_tensors[i] = tensor; } } return aten_xla_tensors; @@ -328,9 +345,23 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor, } at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) { - return xla_tensor ? at::Tensor(c10::make_intrusive( - std::move(xla_tensor))) - : at::Tensor(); + if (xla_tensor) { + auto out = + at::Tensor(c10::make_intrusive(std::move(xla_tensor))); + // See Note [Lazy Tensor Functionalization] + if (c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::Functionalize)) { + // Invariant: if the functionalization key is in the exclude set, then + // we're expected to return an ordinary tensor, which will be "lifted" + // into a functional wrapper later. + return out; + } else { + auto wrapped = at::functionalization::impl::to_functional_tensor(out); + return wrapped; + } + } else { + return at::Tensor(); + } } std::vector AtenFromXlaTensors( diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c6e7abc1ff81..5ff2975835b8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -507,6 +508,54 @@ std::vector XLANativeFunctions::_to_cpu(at::TensorList tensors) { return bridge::XlaCreateTensorList(tensors); } +// TODO(alanwaketan): Improve the error messages. +// Let's rewrite it without reusing other native functions. +at::Tensor XLANativeFunctions::_to_copy( + const at::Tensor& self, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory, bool non_blocking, + c10::optional memory_format) { + TORCH_LAZY_FN_COUNTER("xla::"); + + auto options = self.options(); + // I put each of these setters in a conditional instead of doing + // `self.options().dtype(dtype).layout(layout)... because calling + // .dtype(nullopt) on an options() that already has dtype appears to wipe it + if (dtype) { + options = options.dtype(dtype); + } + if (layout) { + options = options.layout(layout); + } + if (device) { + options = options.device(device); + } + if (pin_memory) { + options = options.pinned_memory(pin_memory); + } + if (memory_format) { + options = options.memory_format(memory_format); + } + + // Case 1: Materialize the tensor. + if (device && device->type() != c10::kXLA) { + XLA_CHECK(device->type() == c10::kCPU) + << "only cpu device is supported in _to_copy."; + auto self_tensor = bridge::GetXlaTensor(self); + auto eager_tensor = self_tensor->ToTensor(/*detached=*/true); + + // Use the eager .to on the eager tensor. + return eager_tensor.to(options, non_blocking, /*copy=*/true); + } + + // Case 2: Create a new XLA tensor with the supplied data and options. 
+ auto new_tensor = + empty_symint(self.sym_sizes(), at::typeMetaToScalarType(options.dtype()), + options.layout(), options.device(), options.pinned_memory(), + options.memory_format_opt()); + return _copy_from(self, new_tensor, non_blocking); +} + at::Tensor& XLANativeFunctions::_index_put_impl_( at::Tensor& self, const c10::List>& indices, const at::Tensor& values, bool accumulate, bool /* unsafe */) { @@ -515,6 +564,17 @@ at::Tensor& XLANativeFunctions::_index_put_impl_( accumulate); } +std::tuple +XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER("xla::"); + XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + auto outputs = tensor_methods::slogdet(self_tensor); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), + bridge::AtenFromXlaTensor(std::get<1>(outputs)), + bridge::AtenFromXlaTensor(XLATensorPtr()), + bridge::AtenFromXlaTensor(XLATensorPtr())); +} + at::Tensor XLANativeFunctions::_log_softmax(const at::Tensor& self, int64_t dim, bool half_to_float) { TORCH_LAZY_FN_COUNTER("xla::"); @@ -558,19 +618,10 @@ at::Tensor XLANativeFunctions::_softmax_backward_data( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim)); } -at::Tensor XLANativeFunctions::_trilinear( - const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3, - at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, - at::IntArrayRef sumdim, int64_t unroll_dim) { - TORCH_LAZY_FN_COUNTER("xla::"); - return at::native::_trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim, - unroll_dim); -} - at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("xla::"); - return view_symint(self, c10::fromIntArrayRefSlow(size)); + return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } at::Tensor XLANativeFunctions::add(const at::Tensor& self, @@ -613,7 +664,7 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self, /*bias=*/bridge::GetXlaTensor(self))); } -at::Tensor XLANativeFunctions::alias(const at::Tensor& self) { +at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::alias(bridge::GetXlaTensor(self))); @@ -649,7 +700,7 @@ at::Tensor XLANativeFunctions::argmin(const at::Tensor& self, tensor_methods::argmin(bridge::GetXlaTensor(self))); } -at::Tensor XLANativeFunctions::as_strided( +at::Tensor XLANativeFunctions::as_strided_copy( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { TORCH_LAZY_FN_COUNTER("xla::"); @@ -667,22 +718,28 @@ at::Tensor XLANativeFunctions::as_strided( XlaHelpers::I64Optional(storage_offset))); } -const at::Tensor& XLANativeFunctions::as_strided_( - const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, +at::Tensor XLANativeFunctions::as_strided_scatter( + const at::Tensor& base, const at::Tensor& mutated_view, + at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + auto base_ = bridge::GetXlaTensor(base); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); - if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, + if (!AsStrided::StrideIsSupported(base_->shape(), xsize, xstride, storage_offset.value_or(0))) { return at::native::call_fallback_fn< - &xla_cpu_fallback, ATEN_OP(as_strided_)>::call(self, size, 
stride, - storage_offset); + &xla_cpu_fallback, ATEN_OP(as_strided_scatter)>::call(base, + mutated_view, + size, stride, + storage_offset); } - tensor_methods::as_strided_(self_tensor, std::move(xsize), std::move(xstride), - XlaHelpers::I64Optional(storage_offset)); - return self; + auto mutated_view_ = bridge::GetXlaTensor(mutated_view); + auto base_clone = tensor_methods::clone(base_); + auto base_clone_slice = tensor_methods::as_strided( + base_clone, xsize, xstride, XlaHelpers::I64Optional(storage_offset)); + tensor_methods::copy_(base_clone_slice, mutated_view_); + return bridge::AtenFromXlaTensor(base_clone); } at::Tensor XLANativeFunctions::atan2(const at::Tensor& self, @@ -975,6 +1032,19 @@ XLANativeFunctions::convolution_backward_overrideable( : at::Tensor()); } +at::Tensor XLANativeFunctions::copy(const at::Tensor& self, + const at::Tensor& src, bool non_blocking) { + TORCH_LAZY_FN_COUNTER("xla::"); + return _copy_from(src, self, non_blocking); +} + +at::Tensor& XLANativeFunctions::copy_(at::Tensor& self, const at::Tensor& src, + bool non_blocking) { + TORCH_LAZY_FN_COUNTER("xla::"); + _copy_from(src, self, non_blocking); + return self; +} + at::Tensor XLANativeFunctions::cross(const at::Tensor& self, const at::Tensor& other, c10::optional dim) { @@ -1015,19 +1085,42 @@ at::Tensor XLANativeFunctions::cumsum(const at::Tensor& self, int64_t dim, tensor_methods::cumsum(self_tensor, dim, dtype)); } +// TODO(alanwaketan): Let's rewrite a without reusing other native functions. +at::Tensor XLANativeFunctions::detach_copy(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER("xla::"); + auto new_tensor = + empty_symint(self.sym_sizes(), at::typeMetaToScalarType(self.dtype()), + c10::nullopt, self.device(), c10::nullopt, c10::nullopt); + return _copy_from(self, new_tensor, true); +} + at::Tensor XLANativeFunctions::diag(const at::Tensor& self, int64_t diagonal) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::diag(bridge::GetXlaTensor(self), diagonal)); } -at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset, - int64_t dim1, int64_t dim2) { +at::Tensor XLANativeFunctions::diagonal_copy(const at::Tensor& self, + int64_t offset, int64_t dim1, + int64_t dim2) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::diagonal(bridge::GetXlaTensor(self), offset, dim1, dim2)); } +at::Tensor XLANativeFunctions::diagonal_scatter(const at::Tensor& base, + const at::Tensor& mutated_view, + int64_t offset, int64_t dim1, + int64_t dim2) { + auto base_ = bridge::GetXlaTensor(base); + auto mutated_view_ = bridge::GetXlaTensor(mutated_view); + auto base_clone = tensor_methods::clone(base_); + auto base_clone_slice = + tensor_methods::diagonal(base_clone, offset, dim1, dim2); + tensor_methods::copy_(base_clone_slice, mutated_view_); + return bridge::AtenFromXlaTensor(base_clone); +} + at::Tensor XLANativeFunctions::div(const at::Tensor& self, const at::Tensor& other) { return torch_xla::XLANativeFunctions::div(self, other, @@ -1102,18 +1195,6 @@ at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output, bridge::GetXlaTensor(self_or_result))); } -at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, - const at::Tensor& indices, - c10::SymInt padding_idx, - bool scale_grad_by_freq, - bool sparse) { - TORCH_LAZY_FN_COUNTER("xla::"); - // TODO: for now route to native, which dispatches supported XLA operations. - // We need to make use of the TPU embedding core here eventually. 
- return at::native::embedding_symint(weight, indices, padding_idx, - scale_grad_by_freq, sparse); -} - at::Tensor XLANativeFunctions::embedding_dense_backward( const at::Tensor& grad_output, const at::Tensor& indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { @@ -1155,13 +1236,13 @@ at::Tensor XLANativeFunctions::empty_strided_symint( auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride); at::Tensor t = empty_symint(sym_size, dtype, layout, device, pin_memory, c10::nullopt); - return torch_xla::XLANativeFunctions::as_strided(t, size, stride, - /*storage_offset=*/0); + return torch_xla::XLANativeFunctions::as_strided_copy(t, size, stride, + /*storage_offset=*/0); } -at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, - at::SymIntArrayRef sym_size, - bool implicit) { +at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, + at::SymIntArrayRef sym_size, + bool implicit) { TORCH_LAZY_FN_COUNTER("xla::"); c10::optional size = c10::asIntArrayRefSlowOpt(sym_size); if (size.has_value()) { @@ -1368,6 +1449,17 @@ at::Tensor& XLANativeFunctions::index_put_( at::Tensor& self, const c10::List>& indices, const at::Tensor& values, bool accumulate) { TORCH_LAZY_FN_COUNTER("xla::"); + bool indices_on_cpu_or_xla = + std::all_of(indices.begin(), indices.end(), + [=](const c10::optional& opt) { + return opt.has_value() && opt->defined() + ? (opt->is_cpu() || bridge::IsXlaTensor(*opt)) + : true; + }); + XLA_CHECK(bridge::IsXlaTensor(self) && indices_on_cpu_or_xla) + << "indices should be either on cpu or on the same" + << " device as the indexed tensor (XLA)." + << " When using XLA, the indexed tensor must be an XLA tensor."; XLA_CHECK(self.scalar_type() == values.scalar_type()); CanonicalIndexInfo canonical_index_info = GetCanonicalIndexInfo(self, indices); @@ -1455,6 +1547,41 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight)); } +at::Tensor XLANativeFunctions::lift(const at::Tensor& tensor) { + TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_INTERNAL_ASSERT( + !at::functionalization::impl::isFunctionalTensor(tensor)); + return at::functionalization::impl::to_functional_tensor(tensor); +} + +at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& tensor) { + TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_INTERNAL_ASSERT( + !at::functionalization::impl::isFunctionalTensor(tensor)); + return at::functionalization::impl::to_functional_tensor(tensor); +} + +std::tuple XLANativeFunctions::linalg_inv_ex( + const at::Tensor& self, bool check_errors) { + TORCH_LAZY_FN_COUNTER("xla::"); + // The default value for `check_errors` is False. And for now, we don't + // do anything differently based on this flag. So when it's set to True, + // we'll fallback to CPU. 
+ if (check_errors) { + return at::native::call_fallback_fn< + &xla_cpu_fallback, ATEN_OP(linalg_inv_ex)>::call(self, check_errors); + } + auto common_device = torch_xla::bridge::GetXlaDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + torch::lazy::NodePtr node = + torch::lazy::MakeNode(bridge::GetXlaTensor(self)->GetIrValue()); + auto result = torch_xla::XLATensor::Create(std::move(node), *common_device); + auto info = tensor_methods::full_like(result, 0, result->GetDevice(), + at::ScalarType::Int); + return std::make_tuple(bridge::AtenFromXlaTensor(result), + bridge::AtenFromXlaTensor(info)); +} + at::Tensor XLANativeFunctions::linspace(const at::Scalar& start, const at::Scalar& end, int64_t steps, c10::optional dtype, @@ -1516,33 +1643,32 @@ at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self, bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor& XLANativeFunctions::masked_fill_(at::Tensor& self, - const at::Tensor& mask, - const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::masked_fill_(self_tensor, bridge::GetXlaTensor(mask), value); - return self; -} - -at::Tensor& XLANativeFunctions::masked_fill_(at::Tensor& self, - const at::Tensor& mask, - const at::Tensor& value) { +at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self, + const at::Tensor& mask, + const at::Tensor& value) { TORCH_LAZY_FN_COUNTER("xla::"); XLA_CHECK_EQ(value.dim(), 0) << "masked_fill_ only supports a 0-dimensional " << "value tensor, but got tensor " << "with " << value.dim() << " dimension(s)."; - return masked_fill_(self, mask, value.item()); + return masked_fill(self, mask, value.item()); } -at::Tensor& XLANativeFunctions::masked_scatter_(at::Tensor& self, - const at::Tensor& mask, - const at::Tensor& source) { +at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self, + const at::Tensor& mask, + const at::Scalar& value) { TORCH_LAZY_FN_COUNTER("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::masked_scatter_(self_tensor, bridge::GetXlaTensor(mask), - bridge::GetXlaTensor(source)); - return self; + return bridge::AtenFromXlaTensor(tensor_methods::masked_fill( + self_tensor, bridge::GetXlaTensor(mask), value)); +} + +at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self, + const at::Tensor& mask, + const at::Tensor& source) { + TORCH_LAZY_FN_COUNTER("xla::"); + XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + return bridge::AtenFromXlaTensor(tensor_methods::masked_scatter( + self_tensor, bridge::GetXlaTensor(mask), bridge::GetXlaTensor(source))); } at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, @@ -2141,8 +2267,8 @@ at::Tensor& XLANativeFunctions::normal_( return self; } -at::Tensor XLANativeFunctions::permute(const at::Tensor& self, - at::IntArrayRef dims) { +at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self, + at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::permute( bridge::GetXlaTensor(self), XlaHelpers::I64List(dims))); @@ -2185,10 +2311,9 @@ at::Tensor XLANativeFunctions::pow(const at::Scalar& self, tensor_methods::pow(self, bridge::GetXlaTensor(exponent))); } -at::Tensor XLANativeFunctions::prelu(const at::Tensor& self, - const at::Tensor& weight) { +at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self, + const at::Tensor& weight) { TORCH_LAZY_FN_COUNTER("xla::"); - // If multiple weights, check channel 
size == number of weights. int64_t weight_num = weight.numel(); if (weight.numel() > 1) { @@ -2229,6 +2354,24 @@ at::Tensor XLANativeFunctions::prod(const at::Tensor& self, int64_t dim, PromoteIntegralType(self.scalar_type(), dtype))); } +void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, + const at::Tensor& output) { + TORCH_LAZY_FN_COUNTER("xla::"); + // This op is only called when functionalize pass is transforming an in-place + // op. Therefore, we can populate some meta data to maintain any optimization + // for in-place ops we have in hands. + + // 1) Aid XLA's InputOutputAlias. + auto input_tensor = bridge::GetXlaTensor(input); + auto output_tensor = bridge::GetXlaTensor(output); + output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + + // 2) Aid SPMD. + if (input_tensor->sharding_spec()) { + output_tensor->SetShardingSpec(*(input_tensor->sharding_spec())); + } +} + at::Tensor& XLANativeFunctions::put_(at::Tensor& self, const at::Tensor& index, const at::Tensor& source, bool accumulate) { @@ -2532,13 +2675,25 @@ at::Tensor XLANativeFunctions::scatter_reduce( } } -at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim, - int64_t index) { +at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim, + int64_t index) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::select(bridge::GetXlaTensor(self), dim, index)); } +at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base, + const at::Tensor& mutated_view, + int64_t dim, int64_t index) { + TORCH_LAZY_FN_COUNTER("xla::"); + auto base_ = bridge::GetXlaTensor(base); + auto mutated_view_ = bridge::GetXlaTensor(mutated_view); + auto base_clone = tensor_methods::clone(base_); + auto base_clone_slice = tensor_methods::select(base_clone, dim, index); + tensor_methods::copy_(base_clone_slice, mutated_view_); + return bridge::AtenFromXlaTensor(base_clone); +} + // TODO(JackCaoG): Remove after elu being codegened at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { TORCH_LAZY_FN_COUNTER("xla::"); @@ -2568,9 +2723,10 @@ at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output, bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output))); } -at::Tensor XLANativeFunctions::slice(const at::Tensor& self, int64_t dim, - c10::optional start, - c10::optional end, int64_t step) { +at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { TORCH_LAZY_FN_COUNTER("xla::"); int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? 
end.value() : INT64_MAX; @@ -2578,13 +2734,19 @@ at::Tensor XLANativeFunctions::slice(const at::Tensor& self, int64_t dim, bridge::GetXlaTensor(self), dim, start_val, end_val, step)); } -std::tuple XLANativeFunctions::slogdet( - const at::Tensor& self) { +at::Tensor XLANativeFunctions::slice_scatter( + const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, + c10::optional start, c10::optional end, int64_t step) { TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - auto outputs = tensor_methods::slogdet(self_tensor); - return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), - bridge::AtenFromXlaTensor(std::get<1>(outputs))); + auto base_ = bridge::GetXlaTensor(base); + auto mutated_view_ = bridge::GetXlaTensor(mutated_view); + auto base_clone = tensor_methods::clone(base_); + int64_t start_val = start.has_value() ? start.value() : 0; + int64_t end_val = end.has_value() ? end.value() : INT64_MAX; + auto base_clone_slice = + tensor_methods::slice(base_clone, dim, start_val, end_val, step); + tensor_methods::copy_(base_clone_slice, mutated_view_); + return bridge::AtenFromXlaTensor(base_clone); } at::Tensor XLANativeFunctions::smooth_l1_loss(const at::Tensor& self, @@ -2645,16 +2807,16 @@ std::tuple XLANativeFunctions::sort( bridge::AtenFromXlaTensor(std::get<1>(results))); } -std::vector XLANativeFunctions::split(const at::Tensor& self, - int64_t split_size, - int64_t dim) { +std::vector XLANativeFunctions::split_copy(const at::Tensor& self, + int64_t split_size, + int64_t dim) { TORCH_LAZY_FN_COUNTER("xla::"); auto xla_tensors = tensor_methods::split(bridge::GetXlaTensor(self), split_size, dim); return bridge::AtenFromXlaTensors(xla_tensors); } -std::vector XLANativeFunctions::split_with_sizes( +std::vector XLANativeFunctions::split_with_sizes_copy( const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) { TORCH_LAZY_FN_COUNTER("xla::"); auto xla_tensors = tensor_methods::split_with_sizes( @@ -2668,32 +2830,19 @@ at::Tensor XLANativeFunctions::sqrt(const at::Tensor& self) { tensor_methods::sqrt(bridge::GetXlaTensor(self))); } -at::Tensor XLANativeFunctions::squeeze(const at::Tensor& self) { +at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::squeeze(bridge::GetXlaTensor(self))); } -at::Tensor XLANativeFunctions::squeeze(const at::Tensor& self, int64_t dim) { +at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, + int64_t dim) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::squeeze(bridge::GetXlaTensor(self), dim)); } -at::Tensor& XLANativeFunctions::squeeze_(at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::squeeze_(self_tensor); - return self; -} - -at::Tensor& XLANativeFunctions::squeeze_(at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::squeeze_(self_tensor, dim); - return self; -} - at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( @@ -2805,19 +2954,12 @@ std::tuple XLANativeFunctions::svd( bridge::AtenFromXlaTensor(std::get<2>(results))); } -at::Tensor XLANativeFunctions::t(const at::Tensor& self) { +at::Tensor XLANativeFunctions::t_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER("xla::"); return 
bridge::AtenFromXlaTensor( tensor_methods::transpose(bridge::GetXlaTensor(self), 0, 1)); } -at::Tensor& XLANativeFunctions::t_(at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::transpose_(self_tensor, 0, 1); - return self; -} - at::Tensor XLANativeFunctions::tanh_backward(const at::Tensor& grad_output, const at::Tensor& output) { TORCH_LAZY_FN_COUNTER("xla::"); @@ -2857,21 +2999,13 @@ at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { tensor_methods::trace(bridge::GetXlaTensor(self))); } -at::Tensor XLANativeFunctions::transpose(const at::Tensor& self, int64_t dim0, - int64_t dim1) { +at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, + int64_t dim0, int64_t dim1) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::transpose(bridge::GetXlaTensor(self), dim0, dim1)); } -at::Tensor& XLANativeFunctions::transpose_(at::Tensor& self, int64_t dim0, - int64_t dim1) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::transpose_(self_tensor, dim0, dim1); - return self; -} - std::tuple XLANativeFunctions::triangular_solve( const at::Tensor& b, const at::Tensor& A, bool upper, bool transpose, bool unitriangular) { @@ -2885,8 +3019,8 @@ std::tuple XLANativeFunctions::triangular_solve( bridge::AtenFromXlaTensor(std::get<1>(results))); } -std::vector XLANativeFunctions::unbind(const at::Tensor& self, - int64_t dim) { +std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, + int64_t dim) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensors( tensor_methods::unbind(bridge::GetXlaTensor(self), dim)); @@ -2906,19 +3040,13 @@ at::Tensor& XLANativeFunctions::uniform_( return self; } -at::Tensor XLANativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) { +at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, + int64_t dim) { TORCH_LAZY_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::unsqueeze(bridge::GetXlaTensor(self), dim)); } -at::Tensor& XLANativeFunctions::unsqueeze_(at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::unsqueeze_(self_tensor, dim); - return self; -} - at::Tensor XLANativeFunctions::upsample_bilinear2d( const at::Tensor& self, at::IntArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { @@ -3057,8 +3185,8 @@ std::tuple XLANativeFunctions::var_mean( bridge::AtenFromXlaTensor(std::get<1>(results))); } -at::Tensor XLANativeFunctions::view_symint(const at::Tensor& self, - at::SymIntArrayRef sym_size) { +at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, + at::SymIntArrayRef sym_size) { // TODO: support symbolic sizes auto size = C10_AS_INTARRAYREF_SLOW(sym_size); TORCH_LAZY_FN_COUNTER("xla::"); @@ -3072,7 +3200,7 @@ at::Tensor XLANativeFunctions::where(const at::Tensor& condition, TORCH_LAZY_FN_COUNTER("xla::"); c10::MaybeOwned b_condition, b_self, b_other; std::tie(b_condition, b_self, b_other) = - expand_outplace(condition, self, other, "where"); + xla_expand_outplace(condition, self, other, "where"); return bridge::AtenFromXlaTensor(tensor_methods::where( bridge::GetXlaTensor(*b_condition), bridge::GetXlaTensor(*b_self), bridge::GetXlaTensor(*b_other))); @@ -3152,56 +3280,182 @@ XLANativeFunctions::native_group_norm(const at::Tensor& input, // core that call into view 
operators internally. These are all composite ops // that LTC can technically re-use / get for free, but we need to // "functionalize" them to remove the view ops before we can use them. +at::Tensor XLANativeFunctions::affine_grid_generator(const at::Tensor& theta, + at::IntArrayRef size, + bool align_corners) { + return at::functionalization::functionalize_aten_op::call(theta, size, align_corners); +} + at::Tensor XLANativeFunctions::block_diag(at::TensorList tensors) { - return at::native::block_diag(tensors); + return at::functionalization::functionalize_aten_op::call(tensors); +} + +at::Tensor XLANativeFunctions::_convolution( + const at::Tensor& input, const at::Tensor& weight, + const c10::optional& bias, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, + at::IntArrayRef output_padding, int64_t groups, bool benchmark, + bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::functionalization::functionalize_aten_op::call(input, weight, bias, stride, padding, dilation, + transposed, output_padding, groups, benchmark, + deterministic, cudnn_enabled, allow_tf32); } + +::std::tuple +XLANativeFunctions::convolution_backward( + const at::Tensor& grad_output, const at::Tensor& input, + const at::Tensor& weight, at::OptionalIntArrayRef bias_sizes, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups, + ::std::array output_mask) { + // TODO (alanwaketan): Let's resuse + // `at::functionalization::functionalize_aten_op` after upstream has solved + // its issue. + // The following is adopted from aten/src/ATen/FunctionalTensorWrapper.cpp: + // functionalize_op_helper. + auto func_grad_output = + at::functionalization::impl::to_functional_tensor(grad_output); + auto func_input = at::functionalization::impl::to_functional_tensor(input); + auto func_weight = at::functionalization::impl::to_functional_tensor(weight); + + auto curr_tls = c10::impl::tls_local_dispatch_key_set(); + auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet(); + tls_reenable_functionalize.set_included(curr_tls.included_); + tls_reenable_functionalize.set_excluded( + curr_tls.excluded_.remove(c10::DispatchKey::Functionalize)); + c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize); + auto results = at::native::convolution_backward( + func_grad_output, func_input, func_weight, bias_sizes, stride, padding, + dilation, transposed, output_padding, groups, output_mask); + + return std::make_tuple( + at::functionalization::impl::from_functional_tensor(std::get<0>(results)), + at::functionalization::impl::from_functional_tensor(std::get<1>(results)), + at::functionalization::impl::from_functional_tensor( + std::get<2>(results))); +} + +at::Tensor XLANativeFunctions::diag_embed(const at::Tensor& self, + int64_t offset, int64_t dim1, + int64_t dim2) { + return at::functionalization::functionalize_aten_op::call(self, offset, dim1, dim2); +} + +at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, + const at::Tensor& indices, + c10::SymInt padding_idx, + bool scale_grad_by_freq, + bool sparse) { + // TODO: for now route to native, which dispatches supported XLA operations. + // We need to make use of the TPU embedding core here eventually. 
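+  // The functionalize_aten_op_symint wrapper below re-wraps the arguments as
+  // functional tensors, re-enables the Functionalize dispatch key, runs the
+  // composite embedding kernel, and unwraps the result, so the view ops inside
+  // the decomposition reach this backend as *_copy variants.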
+ return at::functionalization::functionalize_aten_op_symint::call(weight, indices, padding_idx, scale_grad_by_freq, + sparse); +} + +at::Tensor XLANativeFunctions::_euclidean_dist(const at::Tensor& x1, + const at::Tensor& x2) { + return at::functionalization::functionalize_aten_op::call(x1, x2); +} + at::Tensor XLANativeFunctions::new_empty_strided_symint( const at::Tensor& self, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { - return at::native::new_empty_strided_symint(self, size, stride, dtype, layout, - device, pin_memory); + return at::functionalization::functionalize_aten_op_symint::call(self, size, stride, dtype, layout, device, + pin_memory); } at::Tensor XLANativeFunctions::narrow_copy_symint(const at::Tensor& self, int64_t dim, c10::SymInt start, c10::SymInt length) { - return at::native::narrow_copy_dense_symint(self, dim, start, length); + return at::functionalization::functionalize_aten_op_symint::call(self, dim, start, length); } + at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self, int64_t upscale_factor) { - return at::native::math_pixel_shuffle(self, upscale_factor); + return at::functionalization::functionalize_aten_op::call(self, upscale_factor); } + at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self, int64_t downscale_factor) { - return at::native::math_pixel_unshuffle(self, downscale_factor); + return at::functionalization::functionalize_aten_op::call(self, downscale_factor); } + +at::Tensor XLANativeFunctions::reshape_symint(const at::Tensor& self, + c10::SymIntArrayRef shape) { + return at::functionalization::functionalize_aten_op_symint::call(self, shape); +} + at::Tensor XLANativeFunctions::select_backward_symint( const at::Tensor& grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { - return at::native::select_backward_symint(grad_output, input_sizes, dim, - index); + return at::functionalization::functionalize_aten_op_symint::call(grad_output, input_sizes, dim, index); } + +at::Tensor XLANativeFunctions::select_symint(const at::Tensor& self, + int64_t dim, c10::SymInt index) { + return at::functionalization::functionalize_aten_op_symint::call(self, dim, index); +} + +at::Tensor XLANativeFunctions::slice(const at::Tensor& self, int64_t dim, + c10::optional start, + c10::optional end, int64_t step) { + return at::functionalization::functionalize_aten_op::call(self, dim, start, end, step); +} + +at::Tensor XLANativeFunctions::t(const at::Tensor& self) { + return at::functionalization::functionalize_aten_op::call(self); +} + +at::Tensor XLANativeFunctions::_trilinear( + const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3, + at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, + at::IntArrayRef sumdim, int64_t unroll_dim) { + return at::functionalization::functionalize_aten_op::call(i1, i2, i3, expand1, expand2, expand3, sumdim, + unroll_dim); +} + at::Tensor XLANativeFunctions::linalg_pinv( const at::Tensor& self, const c10::optional& atol, const c10::optional& rtol, bool hermitian) { - return at::native::linalg_pinv(self, atol, rtol, hermitian); + return at::functionalization::functionalize_aten_op::call(self, atol, rtol, hermitian); +} + +at::Tensor XLANativeFunctions::mvlgamma(const at::Tensor& self, int64_t p) { + return at::functionalization::functionalize_aten_op::call( + self, p); } at::Tensor XLANativeFunctions::diagonal_backward_symint( const at::Tensor& 
grad_output, at::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { - return at::native::diagonal_backward_symint(grad_output, input_sizes, offset, - dim1, dim2); + return at::functionalization::functionalize_aten_op_symint::call(grad_output, input_sizes, offset, dim1, dim2); } at::Tensor XLANativeFunctions::slice_backward(const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { - return at::native::slice_backward(grad_output, input_sizes, dim, start, end, - step); + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, start, end, step); } at::Tensor XLANativeFunctions::_cdist_forward( @@ -3216,4 +3470,10 @@ at::Tensor XLANativeFunctions::_cdist_forward( bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p)); } +at::Tensor XLANativeFunctions::permute(const at::Tensor& self, + at::IntArrayRef dims) { + return at::functionalization::functionalize_aten_op::call( + self, dims); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 67045fc78812..b45f9465f27c 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -228,13 +228,8 @@ xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); const xla::Shape& weight_shape = XlaHelpers::ShapeOfXlaOp(weight); - int64_t weight_num = xla::ShapeUtil::ElementsIn(weight_shape); - int64_t broadcast_dim = weight_num == 1 ? 0 : 1; - xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type()); - xla::XlaOp broadcasted_weight = - xla::BroadcastInDim(weight, input_shape.dimensions(), {broadcast_dim}); - xla::XlaOp product = xla::Mul(input, broadcasted_weight); + xla::XlaOp product = xla::Mul(input, weight); return xla::Select(xla::Gt(input, zero), input, product); } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8a7e936dd83b..1e653c049b3e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1609,6 +1609,10 @@ void InitXlaModuleBindings(py::module m) { MapXlaEnvVarsToLazy(); InitXlaBackend(); }); + m.def("_replace_xla_tensor", + [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { + return XLANativeFunctions::set_(self, source); + }); /* The distributed runtime service is used by the PjRt GPU client. */ py::class_getDynamicValue() != dim_node_1->getDynamicValue() ? 1 : 0; +} + +std::string SizeNe::ToString() const { return "aten::size_ne"; } + +SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b) + : XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_ge")}, + {a, b}, + xla::ShapeUtil::MakeShape( + GetShapeDimensionType(/*device=*/nullptr), {}), + 1) { + const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0)); + const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1)); + XLA_CHECK(dim_node_0); + XLA_CHECK(dim_node_1); +}; + +int64_t SizeGe::getDynamicValue() const { + const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0)); + const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1)); + XLA_CHECK(dim_node_0); + XLA_CHECK(dim_node_1); + return dim_node_0->getDynamicValue() >= dim_node_1->getDynamicValue() ? 
1 : 0; +} + +std::string SizeGe::ToString() const { return "aten::size_ge"; } + +SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b) + : XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_lt")}, + {a, b}, + xla::ShapeUtil::MakeShape( + GetShapeDimensionType(/*device=*/nullptr), {}), + 1) { + const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0)); + const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1)); + XLA_CHECK(dim_node_0); + XLA_CHECK(dim_node_1); +}; + +int64_t SizeLt::getDynamicValue() const { + const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0)); + const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1)); + XLA_CHECK(dim_node_0); + XLA_CHECK(dim_node_1); + return dim_node_0->getDynamicValue() < dim_node_1->getDynamicValue() ? 1 : 0; +} + +std::string SizeLt::ToString() const { return "aten::size_lt"; } + SizeConstant::SizeConstant(int64_t val) : Scalar(c10::Scalar{val}, xla::ShapeUtil::MakeShape( diff --git a/torch_xla/csrc/ops/dynamic_ir.h b/torch_xla/csrc/ops/dynamic_ir.h index 88fac1fec338..dadc60becd46 100644 --- a/torch_xla/csrc/ops/dynamic_ir.h +++ b/torch_xla/csrc/ops/dynamic_ir.h @@ -68,6 +68,51 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode { } }; +class SizeNe : public XlaNode, public torch::lazy::DimensionNode { + public: + SizeNe(torch::lazy::Value a, torch::lazy::Value b); + int64_t getDynamicValue() const override; + int64_t getStaticValue() const override { + TORCH_CHECK(false, "Comparison operators should be using getDynamicValue"); + } + bool isSymbolic() const override { return true; } + std::string ToString() const override; + virtual XlaOpVector Lower(LoweringContext* loctx) const override { + // TODO: not sure we will ever need it? + TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!"); + } +}; + +class SizeGe : public XlaNode, public torch::lazy::DimensionNode { + public: + SizeGe(torch::lazy::Value a, torch::lazy::Value b); + int64_t getDynamicValue() const override; + int64_t getStaticValue() const override { + TORCH_CHECK(false, "Comparison operators should be using getDynamicValue"); + } + bool isSymbolic() const override { return true; } + std::string ToString() const override; + virtual XlaOpVector Lower(LoweringContext* loctx) const override { + // TODO: not sure we will ever need it? + TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!"); + } +}; + +class SizeLt : public XlaNode, public torch::lazy::DimensionNode { + public: + SizeLt(torch::lazy::Value a, torch::lazy::Value b); + int64_t getDynamicValue() const override; + int64_t getStaticValue() const override { + TORCH_CHECK(false, "Comparison operators should be using getDynamicValue"); + } + bool isSymbolic() const override { return true; } + std::string ToString() const override; + virtual XlaOpVector Lower(LoweringContext* loctx) const override { + // TODO: not sure we will ever need it? 
+ TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!"); + } +}; + class SizeAdd : public XlaNode, public torch::lazy::DimensionNode { public: SizeAdd(torch::lazy::Value a, torch::lazy::Value b); diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 3243024d3ce3..9f2ebac0a78c 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "tensorflow/compiler/xla/permutation_util.h" #include "third_party/xla_client/debug_macros.h" @@ -19,6 +20,7 @@ #include "torch_xla/csrc/ops/permute.h" #include "torch_xla/csrc/ops/scalar.h" #include "torch_xla/csrc/tensor_methods.h" +#include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_graph_executor.h" #include "torch_xla/csrc/xla_lower_util.h" @@ -62,7 +64,8 @@ std::vector ExpandByteTensors( // Replace with nonzeros. auto nonzero = index->nonzero(); for (int64_t j = 0; j < index->dim(); j++) { - result.emplace_back(nonzero.select(1, j)); + // There is no tensor.select_copy. So at::select_copy is used. + result.emplace_back(at::select_copy(nonzero, 1, j)); } } else { result.emplace_back(index.value_or(at::Tensor())); @@ -244,7 +247,7 @@ CanonicalIndexInfo GetCanonicalIndexInfo( CheckIndexTensorTypes(orig_indices); // First expand ByteTensor (boolean masks) into 1 or more LongTensors, then // broadcast all index tensors together. - auto indices = at::expand_outplace(ExpandByteTensors(base, orig_indices)); + auto indices = xla_expand_outplace(ExpandByteTensors(base, orig_indices)); // If the non-null indices are not all adjacent, transpose base and indices // together so that they're adjacent at the front. CanonicalIndexInfo canonical_index_info = TransposeToFront(base, indices); diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index a3abf55551fc..6249c67c9607 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -622,6 +622,19 @@ bool XLATensor::ShouldSyncIrNode() { return this->data()->ir_value->op() != xla_device_data; } +bool XLASymNodeImpl::is_bool() { + c10::Symbol op = node()->op().op; + // Reference: + // https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/symbolic_shapes.py#L403 + if (op == c10::Symbol::fromQualString("aten::size_eq") || + op == c10::Symbol::fromQualString("aten::size_ne") || + op == c10::Symbol::fromQualString("aten::size_ge") || + op == c10::Symbol::fromQualString("aten::size_lt")) { + return true; + } + return false; +} + bool XLASymNodeImpl::is_int() { // TODO: handle not is int return true; @@ -683,8 +696,10 @@ c10::SymNode XLASymNodeImpl::eq(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::ne(const c10::SymNode& other) { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + TORCH_LAZY_FN_COUNTER("xla::size_"); + auto p_other = dynamic_cast(other.get()); + auto n_ne = torch::lazy::MakeNode(node(), p_other->node()); + return c10::make_intrusive(n_ne); } c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) { @@ -693,8 +708,10 @@ c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + TORCH_LAZY_FN_COUNTER("xla::size_"); + auto p_other = dynamic_cast(other.get()); + auto n_lt = torch::lazy::MakeNode(node(), p_other->node()); + return c10::make_intrusive(n_lt); } c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) { 
@@ -703,8 +720,10 @@ c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::ge(const c10::SymNode& other) { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + TORCH_LAZY_FN_COUNTER("xla::size_"); + auto p_other = dynamic_cast(other.get()); + auto n_ge = torch::lazy::MakeNode(node(), p_other->node()); + return c10::make_intrusive(n_ge); } c10::SymNode XLASymNodeImpl::ceil() { @@ -732,10 +751,56 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) { << " has not been implemented."; } -c10::SymNode XLASymNodeImpl::clone() { +c10::SymNode XLASymNodeImpl::sym_or(const c10::SymNode& other) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} + +c10::SymNode XLASymNodeImpl::sym_and(const c10::SymNode& other) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} + +c10::SymNode XLASymNodeImpl::sym_not() { XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ << " has not been implemented."; } +// NB: self is ignored here, only the arguments are used +c10::SymNode XLASymNodeImpl::is_contiguous(at::ArrayRef sizes, + at::ArrayRef strides) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} +c10::SymNode XLASymNodeImpl::is_channels_last_contiguous_2d( + at::ArrayRef sizes, at::ArrayRef strides) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} +c10::SymNode XLASymNodeImpl::is_channels_last_contiguous_3d( + at::ArrayRef sizes, at::ArrayRef strides) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} +c10::SymNode XLASymNodeImpl::is_channels_last_strides_2d( + at::ArrayRef sizes, at::ArrayRef strides) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} +c10::SymNode XLASymNodeImpl::is_channels_last_strides_3d( + at::ArrayRef sizes, at::ArrayRef strides) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} +c10::SymNode XLASymNodeImpl::is_non_overlapping_and_dense( + at::ArrayRef sizes, at::ArrayRef strides) { + XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ + << " has not been implemented."; +} + +c10::SymNode XLASymNodeImpl::clone() { + TORCH_LAZY_FN_COUNTER("xla::size_"); + return c10::make_intrusive(node()); +} c10::SymNode XLASymNodeImpl::sym_float() { XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ @@ -762,6 +827,11 @@ double XLASymNodeImpl::guard_float(const char* file, int64_t line) { << " has not been implemented."; } +bool XLASymNodeImpl::guard_bool(const char* file, int64_t line) { + // TODO: Take advantages of file and line. 
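+  // The file/line guard context is not recorded yet (see the TODO above); the
+  // condition is materialized immediately through bool_(), which reads
+  // getDynamicValue() off the underlying comparison node
+  // (SizeEq/SizeNe/SizeGe/SizeLt).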
+ return bool_(); +} + int64_t XLASymNodeImpl::int_() { std::shared_ptr dn = torch_xla::DimCast(node()); return dn->getDynamicValue(); @@ -772,6 +842,8 @@ bool XLASymNodeImpl::bool_() { return dn->getDynamicValue() != 0; } +bool XLASymNodeImpl::has_hint() { return true; } + std::string XLASymNodeImpl::str() { return "<=" + std::to_string(DimCast(node().get())->getStaticValue()); } diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 05b7b6f83c0c..a3c4accca125 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -30,6 +30,7 @@ namespace torch_xla { class TORCH_API XLASymNodeImpl : public c10::SymNodeImpl { public: XLASymNodeImpl(torch::lazy::NodePtr ptr) : node_(std::move(ptr)) {} + bool is_bool() override; bool is_int() override; bool is_float() override; c10::SymNode add(const c10::SymNode& other) override; @@ -50,14 +51,37 @@ class TORCH_API XLASymNodeImpl : public c10::SymNodeImpl { c10::SymNode neg() override; c10::SymNode sym_min(const c10::SymNode& other) override; c10::SymNode sym_max(const c10::SymNode& other) override; + c10::SymNode sym_or(const c10::SymNode& other) override; + c10::SymNode sym_and(const c10::SymNode& other) override; + c10::SymNode sym_not() override; + // NB: self is ignored here, only the arguments are used + c10::SymNode is_contiguous(at::ArrayRef sizes, + at::ArrayRef strides) override; + c10::SymNode is_channels_last_contiguous_2d( + at::ArrayRef sizes, + at::ArrayRef strides) override; + c10::SymNode is_channels_last_contiguous_3d( + at::ArrayRef sizes, + at::ArrayRef strides) override; + c10::SymNode is_channels_last_strides_2d( + at::ArrayRef sizes, + at::ArrayRef strides) override; + c10::SymNode is_channels_last_strides_3d( + at::ArrayRef sizes, + at::ArrayRef strides) override; + c10::SymNode is_non_overlapping_and_dense( + at::ArrayRef sizes, + at::ArrayRef strides) override; c10::SymNode clone() override; c10::SymNode sym_float() override; c10::SymNode wrap_int(int64_t num) override; c10::SymNode wrap_float(double num) override; int64_t guard_int(const char* file, int64_t line) override; double guard_float(const char* file, int64_t line) override; + bool guard_bool(const char* file, int64_t line) override; int64_t int_() override; bool bool_() override; + bool has_hint() override; std::string str() override; torch::lazy::NodePtr node() { return node_; } @@ -84,25 +108,33 @@ class XLATensor : public torch::lazy::LazyTensor { ShardingSpecPtr sharding = nullptr) : torch::lazy::LazyTensor::Data(handle, device), logical_element_type(logical_element_type), - sharding(sharding) {} + sharding(sharding) { + alias_id = unique_id; + } Data(torch::lazy::Value ir_value, const torch::lazy::BackendDevice& device, c10::optional logical_element_type, ShardingSpecPtr sharding = nullptr) : torch::lazy::LazyTensor::Data(ir_value, device), logical_element_type(logical_element_type), - sharding(sharding) {} + sharding(sharding) { + alias_id = unique_id; + } Data(at::Tensor tensor_data, const torch::lazy::BackendDevice& device, ShardingSpecPtr sharding = nullptr) : torch::lazy::LazyTensor::Data(tensor_data, device), logical_element_type(tensor_data.scalar_type()), - sharding(sharding) {} + sharding(sharding) { + alias_id = unique_id; + } Data(std::shared_ptr view, const torch::lazy::BackendDevice& device, c10::optional logical_element_type, ShardingSpecPtr sharding = nullptr) : torch::lazy::LazyTensor::Data(device), view(std::move(view)), logical_element_type(logical_element_type), - sharding(sharding) {} + sharding(sharding) { + alias_id 
= unique_id; + } ~Data(); @@ -114,6 +146,10 @@ class XLATensor : public torch::lazy::LazyTensor { // A copy of the sharding spec is attached to the IR node via // `SetShardingSpec` and also during the sync tensor collection. ShardingSpecPtr sharding; + // This is used to enable XLA's InputOutputAlias. It's inited + // with unique_id, and then only get updated during the in-place + // op funtionalize pass to point to the input. + int64_t alias_id{0}; }; static XLATensorPtr Create(const at::Tensor& tensor, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 76fcedc8158b..0aac7da7890c 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1536,19 +1536,36 @@ XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other) { return DispatchComparisonOp(at::aten::lt, input, other); } -void masked_fill_(XLATensorPtr& input, const XLATensorPtr& mask, - const at::Scalar& value) { +XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask, + const at::Scalar& value) { torch::lazy::ScopePusher ir_scope(at::aten::masked_fill.toQualString()); - input->SetIrValue(torch::lazy::MakeNode( - input->GetIrValue(), MaybeExpand(mask->GetIrValue(), input->shape()), + auto input_value = input->GetIrValue(); + // Expand input tensor to mask if needed (same as masked_scatter below). + // An additional check makes sure to only expand if the rank of input tensor + // is less than that of the mask tensor. + if (input->shape().get().rank() <= mask->shape().get().rank() && + input->shape().get().dimensions() < mask->shape().get().dimensions()) { + input_value = MaybeExpand(input->GetIrValue(), mask->shape()); + } + return input->CreateFrom(torch::lazy::MakeNode( + input_value, MaybeExpand(mask->GetIrValue(), GetXlaShape(input_value)), value)); } -void masked_scatter_(XLATensorPtr& input, const XLATensorPtr& mask, - const XLATensorPtr& source) { +XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask, + const XLATensorPtr& source) { torch::lazy::ScopePusher ir_scope(at::aten::masked_scatter.toQualString()); - input->SetIrValue(torch::lazy::MakeNode( - input->GetIrValue(), MaybeExpand(mask->GetIrValue(), input->shape()), + auto input_value = input->GetIrValue(); + // This ensures that input tensor is at least the same shape as mask tensor. + // Note that we can't use the existing MaybeExpand function since + // input tensor may sometimes be bigger than the mask tensor, and MaybeExpand + // requires the first parameter to always be less or equal to the second + // parameter. 
+ if (input->shape().get().dimensions() < mask->shape().get().dimensions()) { + input_value = MaybeExpand(input->GetIrValue(), mask->shape()); + } + return input->CreateFrom(torch::lazy::MakeNode( + input_value, MaybeExpand(mask->GetIrValue(), GetXlaShape(input_value)), source->GetIrValue())); } @@ -2359,16 +2376,6 @@ XLATensorPtr squeeze(const XLATensorPtr& input, int64_t dim) { return view(input, output_dimensions); } -void squeeze_(XLATensorPtr& input) { - input->SetIrValue(torch::lazy::MakeNode(input->GetIrValue(), -1)); -} - -void squeeze_(XLATensorPtr& input, int64_t dim) { - input->SetIrValue(torch::lazy::MakeNode( - input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex( - dim, input->shape().get().rank()))); -} - XLATensorPtr stack(absl::Span tensors, int64_t dim) { XLA_CHECK_GT(tensors.size(), 0); std::vector values; @@ -2532,18 +2539,6 @@ XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1) { return input->CreateViewTensor(std::move(view_info)); } -void transpose_(XLATensorPtr& input, int64_t dim0, int64_t dim1) { - auto input_shape = input->shape(); - if (input_shape.get().rank() <= 1) { - // no op if input rank <=1 - return; - } - auto permute_dims = torch::lazy::MakeTransposePermutation( - /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.get().rank()); - ViewInfo view_info(ViewInfo::Type::kPermute, input_shape, permute_dims); - return input->ModifyCurrentView(std::move(view_info)); -} - std::tuple triangular_solve( const XLATensorPtr& rhs, const XLATensorPtr& lhs, bool left_side, bool upper, bool transpose, bool unitriangular) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 2ce51c3e4d2e..7c8c7b0b2a62 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -484,12 +484,11 @@ XLATensorPtr lt(const XLATensorPtr& input, const at::Scalar& other); XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other); -// In-place version of the method above. -void masked_fill_(XLATensorPtr& input, const XLATensorPtr& mask, - const at::Scalar& value); +XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask, + const at::Scalar& value); -void masked_scatter_(XLATensorPtr& input, const XLATensorPtr& mask, - const XLATensorPtr& source); +XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask, + const XLATensorPtr& source); XLATensorPtr masked_select(const XLATensorPtr& input, const XLATensorPtr& mask); diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index 4b71ff1a7c4c..d432fd60c162 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -107,4 +107,59 @@ bool RequiresRawTypeCasting(at::ScalarType scalar_type, xla::PrimitiveType GetShapeDimensionType( const torch::lazy::BackendDevice* device); +// The following functions are copied from aten/src/ATen/ExpandUtils.h just to +// replace the expand with expand_copy. +// TODO(alanwaketan): Fix the upstream. 
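+// These helpers are used by XLANativeFunctions::where and by
+// GetCanonicalIndexInfo in ops/index_ops.cpp, where the broadcast must not
+// introduce view ops that this backend no longer lowers directly.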
+inline std::tuple, c10::MaybeOwned, + c10::MaybeOwned> +xla_expand_outplace(const at::Tensor& to_expand1, const at::Tensor& to_expand2, + const at::Tensor& to_expand3, const char* api_name) { + at::check_defined({to_expand1, to_expand2, to_expand3}, api_name); + if (to_expand1.sizes().equals(to_expand2.sizes()) && + to_expand1.sizes().equals(to_expand3.sizes())) { + return std::make_tuple(c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2), + c10::MaybeOwned::borrowed(to_expand3)); + } + + auto expanded_size12 = + at::infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes()); + auto expanded_size = + at::infer_size_dimvector(expanded_size12, to_expand3.sizes()); + return std::make_tuple(c10::MaybeOwned::owned( + at::expand_copy(to_expand1, expanded_size)), + c10::MaybeOwned::owned( + at::expand_copy(to_expand2, expanded_size)), + c10::MaybeOwned::owned( + at::expand_copy(to_expand3, expanded_size))); +} + +inline std::vector xla_expand_outplace(at::TensorList to_expand) { + // expands a list of Tensors; ignores undefined (null) tensors + bool first = true; + at::DimVector sizes; + for (const auto i : c10::irange(to_expand.size())) { + if (!to_expand[i].defined()) { + continue; + } else if (first) { + sizes = to_expand[i].sizes(); + first = false; + } else { + sizes = at::infer_size_dimvector(sizes, to_expand[i].sizes()); + } + } + + std::vector result(to_expand.size()); + for (const auto i : c10::irange(to_expand.size())) { + if (!to_expand[i].defined()) { + continue; + } else if (to_expand[i].sizes().equals(sizes)) { + result[i] = to_expand[i]; + } else { + result[i] = at::expand_copy(to_expand[i], sizes); + } + } + return result; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 2e9a10ecc361..8f2ce1f46dab 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1034,7 +1034,7 @@ XLAGraphExecutor::BuildInputOutputAliases( // those buffers are no longer needed after execution. for (size_t i = 0; i < indices.size(); ++i) { size_t tensor_index = indices[i]; - int64_t tensor_id = tensors[tensor_index]->GetUniqueId(); + int64_t tensor_id = tensors[tensor_index]->data()->alias_id; output_tensor_id_map[tensor_id] = i; } const auto& parameters_data = lowering_ctx->GetParametersData(); diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index 3ca205f2ce17..74ef9d32e361 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -31,6 +31,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter from torch.nn.utils.rnn import PackedSequence +import torch_xla import torch_xla.core.xla_model as xm from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper @@ -656,14 +657,14 @@ def _shard_parameters_(self, params_to_shard) -> None: ".", "_FSDP_SHARD_SEPARATOR_") self.register_parameter(p_shard._name, p_shard) self.sharded_params.append(p_shard) - # Free the full parameter storage (here we free its internal XLATensor) but keep the tensor itself - # for auto-grad tracing (like `torch.autograd.Variable` before the tensor-variable merge). 
- p.set_(p.new_zeros(1)) if p.device != self.xla_device: # cast to XLA device if not already on XLA p = p.to(self.xla_device).requires_grad_(p.requires_grad) # update p in full_params since id(p) changed after the casting self.full_params[idx] = p + # Free the full parameter storage (here we free its internal XLATensor) but keep the tensor itself + # for auto-grad tracing (like `torch.autograd.Variable` before the tensor-variable merge). + torch_xla._XLAC._replace_xla_tensor(p, p.new_zeros(1)) p._sharded_param = p_shard # add a handle to the sharded parameter p._has_full_param = False # deregister the full parameter tensors from their modules (so that they won't @@ -1361,10 +1362,12 @@ def _rebuild_full_params(self, self.optimization_barrier_op([p_padded]) with torch.autograd._unsafe_preserve_version_counter(p): if self._shard_param_on_dim_0: - p.set_(p_padded[:p_shard._orig_size[0]]) + torch_xla._XLAC._replace_xla_tensor( + p, p_padded[:p_shard._orig_size[0]]) else: - p.set_(p_padded[:p_shard._orig_size.numel()].view( - p_shard._orig_size)) + torch_xla._XLAC._replace_xla_tensor( + p, + p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size)) p._has_full_param = True self.has_full_params = True @@ -1395,7 +1398,7 @@ def _free_full_params(self, if p._has_full_param: # free the original full parameter with torch.autograd._unsafe_preserve_version_counter(p): - p.set_(self._dummy_data_placeholder) + torch_xla._XLAC._replace_xla_tensor(p, self._dummy_data_placeholder) p._has_full_param = False if apply_opt_barrier: diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index b89c92ed295e..ff9b7d1bc7d1 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -169,7 +169,8 @@ def send(self, tensors, dst_rank, tag=0): input_as_result = xm.send(t, channel_id) # Make the sent tensor depend on the token, such that the `send` # op can actually be built into the computation graph. 
- t.data = input_as_result + with torch.no_grad(): + t.copy_(input_as_result) results.append(input_as_result) return _ret_work(results) diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index dd583280f392..0970ae70f5bb 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -113,6 +113,7 @@ supported: - _copy_from - _copy_from_and_resize - _index_put_impl_ + - _linalg_slogdet - _linalg_svd - _local_scalar_dense - _log_softmax @@ -121,19 +122,19 @@ supported: - _softmax - _softmax_backward_data - _to_cpu - - _trilinear + - _to_copy - _unsafe_view - adaptive_max_pool2d - adaptive_max_pool2d_backward - add.Scalar - add.Tensor - addmm - - alias + - alias_copy - arange.start_out - argmax - argmin - - as_strided - - as_strided_ + - as_strided_copy + - as_strided_scatter - atan2 - avg_pool2d - avg_pool2d_backward @@ -158,21 +159,24 @@ supported: - constant_pad_nd - convolution_backward_overrideable - convolution_overrideable + - copy + - copy_ - cross - cumprod - cumsum + - detach_copy - diag - - diagonal + - diagonal_copy + - diagonal_scatter - div.Scalar - div.Tensor - div.Tensor_mode - dot - elu_backward - - embedding - embedding_dense_backward - empty.memory_format - empty_strided - - expand + - expand_copy - exponential_ - eye.m_out - eye.out @@ -199,15 +203,18 @@ supported: - leaky_relu_backward - lerp.Scalar - lerp.Tensor + - lift + - lift_fresh + - linalg_inv_ex - linspace - log - log1p - log2 - log10 - logsumexp - - masked_fill_.Scalar - - masked_fill_.Tensor - - masked_scatter_ + - masked_fill.Scalar + - masked_fill.Tensor + - masked_scatter - masked_select - max - max.dim @@ -248,13 +255,14 @@ supported: - normal.Tensor_float - normal.Tensor_Tensor - normal_ - - permute + - permute_copy - pow.Scalar - pow.Tensor_Scalar - pow.Tensor_Tensor - - prelu + - _prelu_kernel - prod - prod.dim_int + - _propagate_xla_data - put_ - qr - random_ @@ -281,25 +289,25 @@ supported: - scatter_add - scatter_reduce.two - select.int + - select_copy.int + - select_scatter - selu_ - set_.source_Tensor - sigmoid - sigmoid_backward - - slice.Tensor - - slogdet + - slice_copy.Tensor + - slice_scatter - smooth_l1_loss - smooth_l1_loss_backward - softplus - softplus_backward - sort - sort.stable - - split.Tensor - - split_with_sizes + - split_copy.Tensor + - split_with_sizes_copy - sqrt - - squeeze - - squeeze.dim - - squeeze_ - - squeeze_.dim + - squeeze_copy + - squeeze_copy.dim - stack - std - std.correction @@ -310,60 +318,73 @@ supported: - sum - sum.dim_IntList - svd - - t - - t_ + - t_copy - tanh_backward - threshold - threshold_backward - topk - trace - - transpose.int - - transpose_ + - transpose_copy.int - triangular_solve - - unbind.int + - unbind_copy.int - uniform_ - - unsqueeze - - unsqueeze_ + - unsqueeze_copy - upsample_bilinear2d - upsample_bilinear2d_backward - upsample_nearest2d - upsample_nearest2d_backward - var.correction - var_mean.correction - - view + - view_copy - where.self - xlogy.Tensor - zero_ - _native_batch_norm_legit - _native_batch_norm_legit.no_stats + # Note: [functionalization and CompositeExplicitAutograd] # Below are all operators that are "composite" in core, # but require us to explicitly re-enable functionalization in order to use them. # Why? These operators are all CompositeExplicitAutograd, which mean that they run # after functionalization, # but their implementations call view operators (which we need to functionalize away). 
+ - affine_grid_generator - block_diag + - _convolution + - convolution_backward + - diag_embed + - embedding + - _euclidean_dist - slice_backward - diagonal_backward - new_empty_strided - narrow_copy - pixel_shuffle - pixel_unshuffle + - reshape - select_backward + - select.int + - slice.Tensor + - t + - _trilinear - linalg_pinv.atol_rtol_tensor - _cdist_forward + - mvlgamma + - permute # The same applies to these ops, but we already have direct lowerings for them - # - _trilinear # - logsumexp.out symint: - embedding - empty.memory_format - empty_strided - - expand + - expand_copy - new_empty_strided - - view + - view_copy - diagonal_backward - narrow_copy - select_backward + - select.int + # See Note: [functionalization and CompositeExplicitAutograd] + - reshape autograd: - einsum - max_pool2d
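For reference, a minimal sketch of the routing pattern used for the composite ops in the list above. It assumes the ATEN_OP macro from ATen/Operators.h and the functionalize_aten_op / functionalize_aten_op_symint helpers from ATen's FunctionalTensorWrapper machinery; the free-function names xla_diag_embed and xla_reshape are placeholders for this sketch, not symbols from the patch.

#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Operators.h>

namespace {

at::Tensor xla_diag_embed(const at::Tensor& self, int64_t offset, int64_t dim1,
                          int64_t dim2) {
  // Wraps the arguments in functional tensors, re-enables the Functionalize
  // dispatch key, runs the composite diag_embed kernel, and unwraps the
  // result, so its internal view calls reach the backend as *_copy ops.
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      diag_embed)>::call(self, offset, dim1, dim2);
}

at::Tensor xla_reshape(const at::Tensor& self, c10::SymIntArrayRef shape) {
  // SymInt-taking schemas go through the _symint variant of the helper.
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      reshape)>::call(self, shape);
}

}  // namespace

Ops whose composite kernels trip over the upstream helper (convolution_backward above) inline the same wrap / re-dispatch / unwrap steps by hand instead.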