Update lift_fresh ops
wonjoolee95 committed Dec 14, 2022
1 parent 9705938 · commit dc3ed92
Showing 2 changed files with 21 additions and 10 deletions.
29 changes: 20 additions & 9 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -53,6 +53,8 @@ namespace {
 at::Tensor to_meta(const at::Tensor& tensor) {
   // undefined tensors can't be converted to the meta device, since they don't
   // have sizes/strides
+  std::cout << "WONJOO: at aten_xla_type.cpp, to_meta1" << std::endl;
+  std::cout << "WONJOO: at aten_xla_type.cpp, to_meta2, tensor_is_functional=" << at::functionalization::impl::isFunctionalTensor(tensor) << std::endl;
   if (!tensor.defined()) return tensor;
   auto out = at::native::empty_strided_meta_symint(
       tensor.sym_sizes(), tensor.sym_strides(),
@@ -458,6 +460,7 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale,
 at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self,
                                           const at::Tensor& dst,
                                           bool non_blocking) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, _copy_from" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   auto dst_tensor = bridge::TryGetXlaTensor(dst);
   auto self_tensor = bridge::TryGetXlaTensor(self);
@@ -652,6 +655,7 @@ at::Tensor XLANativeFunctions::argmin(const at::Tensor& self,
 at::Tensor XLANativeFunctions::as_strided_copy(
     const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
     c10::optional<int64_t> storage_offset) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, as_strided_copy1" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   auto xsize = XlaHelpers::I64List(size);
@@ -671,6 +675,7 @@ at::Tensor XLANativeFunctions::as_strided_scatter(
     const at::Tensor& base, const at::Tensor& mutated_view,
     at::IntArrayRef size, at::IntArrayRef stride,
     c10::optional<int64_t> storage_offset) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, as_strided_scatter1" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   auto base_ = bridge::GetXlaTensor(base);
   auto xsize = XlaHelpers::I64List(size);
@@ -1148,6 +1153,7 @@ at::Tensor XLANativeFunctions::empty_symint(
     c10::optional<at::Layout> layout, c10::optional<at::Device> device,
     c10::optional<bool> pin_memory,
     c10::optional<at::MemoryFormat> /* memory_format */) {
+  std::cout << "WONJOO: at XLANativeFunctions::empty_symint" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   auto size = c10::asIntArrayRefSlow(sym_size);
   // PT empty*() are optimizations to avoid initializing the data when it is
@@ -1381,6 +1387,8 @@ at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim,
 at::Tensor& XLANativeFunctions::index_put_(
     at::Tensor& self, const c10::List<c10::optional<at::Tensor>>& indices,
     const at::Tensor& values, bool accumulate) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, input_put_1" << std::endl;
+  std::cout << "WONJOO: at aten_xla_type.cpp, input_put_2, self.is_functional=" << at::functionalization::impl::isFunctionalTensor(self) << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   XLA_CHECK(self.scalar_type() == values.scalar_type());
   CanonicalIndexInfo canonical_index_info =
@@ -1399,6 +1407,7 @@ at::Tensor& XLANativeFunctions::index_put_(
       canonical_index_info.start_dim,
       bridge::GetOrCreateXlaTensor(values, *device), accumulate,
       canonical_index_info.result_permutation);
+  std::cout << "WONJOO: at aten_xla_type.cpp, input_put_3, self.is_functional=" << at::functionalization::impl::isFunctionalTensor(self) << std::endl;
   return self;
 }
 
@@ -1468,17 +1477,18 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self,
       bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight));
 }
 
-at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& self) {
-  std::cout << "WONJOO: at aten_xla_type.cpp, lift_fresh" << std::endl;
-  return at::functionalization::impl::to_functional_tensor(self);
-  // return at::functionalization::functionalize_aten_op<ATEN_OP(
-  //     lift_fresh)>::call(self);
+at::Tensor XLANativeFunctions::lift(const at::Tensor& tensor) {
+  std::cout << "WONJOO: at XLANativeFunctions::lift" << std::endl;
+  TORCH_INTERNAL_ASSERT(
+      !at::functionalization::impl::isFunctionalTensor(tensor));
+  return at::functionalization::impl::to_functional_tensor(tensor);
 }
 
-at::Tensor XLANativeFunctions::lift_fresh_copy(const at::Tensor& self) {
-  std::cout << "WONJOO: at aten_xla_type.cpp, lift_fresh_copy" << std::endl;
-  return at::functionalization::functionalize_aten_op<ATEN_OP(
-      lift_fresh_copy)>::call(self);
+at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& tensor) {
+  std::cout << "WONJOO: at XLANativeFunctions::lift_fresh" << std::endl;
+  TORCH_INTERNAL_ASSERT(
+      !at::functionalization::impl::isFunctionalTensor(tensor));
+  return at::functionalization::impl::to_functional_tensor(tensor);
 }
 
 at::Tensor XLANativeFunctions::linspace(const at::Scalar& start,
@@ -3114,6 +3124,7 @@ at::Tensor XLANativeFunctions::block_diag(at::TensorList tensors) {
   return at::functionalization::functionalize_aten_op<ATEN_OP(
       block_diag)>::call(tensors);
 }
+
 at::Tensor XLANativeFunctions::new_empty_strided_symint(
     const at::Tensor& self, at::SymIntArrayRef size, at::SymIntArrayRef stride,
     c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
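For context, both new overrides share one pattern: assert the input is not already a functional tensor, then wrap it. The sketch below restates that pattern outside the diff, assuming ATen's functionalization headers; the helper name wrap_for_functionalization is illustrative and not part of the commit.

// Minimal sketch of the shared lift/lift_fresh pattern; the real
// implementations are XLANativeFunctions::lift and
// XLANativeFunctions::lift_fresh in the diff above.
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>

at::Tensor wrap_for_functionalization(const at::Tensor& tensor) {
  // lift/lift_fresh are only called on plain tensors that have not yet
  // entered the functionalization layer, so a functional input is a bug.
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  // Wrap the tensor in a FunctionalTensorWrapper so later mutations can
  // be tracked and replayed as out-of-place ops.
  return at::functionalization::impl::to_functional_tensor(tensor);
}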
2 changes: 1 addition & 1 deletion xla_native_functions.yaml
@@ -197,8 +197,8 @@ supported:
 - leaky_relu_backward
 - lerp.Scalar
 - lerp.Tensor
+- lift
 - lift_fresh
-- lift_fresh_copy
 - linspace
 - log
 - log1p
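Listing lift and lift_fresh under supported: (and dropping lift_fresh_copy) routes those ATen ops to the overrides above for XLA tensors. A hedged usage sketch, assuming the generated at::lift_fresh entry point from core ATen; lift_example is an illustrative name, not code from this commit.

// Sketch only: shows where a call lands once the ops are registered for
// the XLA dispatch key via xla_native_functions.yaml.
#include <ATen/ATen.h>

void lift_example(const at::Tensor& xla_tensor) {
  // For a tensor on an XLA device this should dispatch to
  // XLANativeFunctions::lift_fresh above, returning the input wrapped in
  // a FunctionalTensorWrapper.
  at::Tensor wrapped = at::lift_fresh(xla_tensor);
  (void)wrapped;  // silence unused-variable warnings in this sketch
}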
