Update XLATensor namespace calls to tensor_methods
wonjoolee95 committed Dec 6, 2022
1 parent a81964f commit 9e5f141
Showing 3 changed files with 48 additions and 35 deletions.
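
The change is mechanical: operations that were previously invoked as static methods on XLATensor (XLATensor::clone, XLATensor::select, XLATensor::copy_, and so on) are now called as free functions in the tensor_methods namespace with unchanged arguments. A minimal before/after sketch of one call site, lifted from the select_scatter hunk below (an excerpt for illustration, not a standalone program):

// Before: lowering ops were invoked as static methods on XLATensor.
auto base_clone = XLATensor::clone(base_);
auto base_clone_slice = XLATensor::select(base_clone, dim, index);
XLATensor::copy_(base_clone_slice, mutated_view_);

// After: the same calls go through the tensor_methods namespace;
// the arguments and surrounding logic are untouched.
auto base_clone = tensor_methods::clone(base_);
auto base_clone_slice = tensor_methods::select(base_clone, dim, index);
tensor_methods::copy_(base_clone_slice, mutated_view_);
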
3 changes: 2 additions & 1 deletion torch_xla/csrc/aten_cpu_fallback.cpp
@@ -17,7 +17,8 @@ static std::unordered_map<std::string, ::xla::metrics::Counter*>
_cpu_fallback_counters;

void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
std::cout << "WONJOO: at aten_cpu_fallback.cpp, xla_cpu_fallback1" << std::endl;
std::cout << "WONJOO: at aten_cpu_fallback.cpp, xla_cpu_fallback1"
<< std::endl;
XLA_FN_TRACK(3);
const auto name = c10::toString(op.operator_name());

22 changes: 15 additions & 7 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -105,7 +105,9 @@ std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {

torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
const at::Tensor& tensor, const torch::lazy::BackendDevice& device) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetXlaTensorOrCreateForWrappedNumber" << std::endl;
std::cout
<< "WONJOO: at aten_xla_bridge.cpp, GetXlaTensorOrCreateForWrappedNumber"
<< std::endl;
if (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
(tensor.dim() == 0 && tensor.numel() == 1)) {
return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
@@ -116,7 +118,8 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(

XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
const torch::lazy::BackendDevice& device) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor"
<< std::endl;
if (!tensor.defined()) {
return XLATensorPtr();
}
@@ -130,7 +133,8 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,

XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
const torch::lazy::BackendDevice& device) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor" << std::endl;
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensor"
<< std::endl;
if (!IsDefined(tensor)) {
return XLATensorPtr();
}
@@ -142,7 +146,8 @@ XLATensorPtr GetOrCreateXlaTensor(const c10::optional<at::Tensor>& tensor,
std::vector<XLATensorPtr> GetOrCreateXlaTensors(
absl::Span<const at::Tensor> tensors,
const torch::lazy::BackendDevice& device) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensors" << std::endl;
std::cout << "WONJOO: at aten_xla_bridge.cpp, GetOrCreateXlaTensors"
<< std::endl;
std::vector<XLATensorPtr> xla_tensors;
for (const at::Tensor& tensor : tensors) {
xla_tensors.push_back(bridge::GetOrCreateXlaTensor(tensor, device));
@@ -151,7 +156,8 @@ std::vector<XLATensorPtr> GetOrCreateXlaTensors(
}

std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateTensorList" << std::endl;
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateTensorList"
<< std::endl;
std::vector<at::Tensor> aten_xla_tensors(tensors.size());
std::vector<XLATensorPtr> xla_tensors;
// We need to separate out the defined tensors first, GetXlaTensor() doesn't
@@ -191,7 +197,8 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {

std::vector<c10::optional<at::Tensor>> XlaCreateOptTensorList(
const std::vector<c10::optional<at::Tensor>>& tensors) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateOptTensorList" << std::endl;
std::cout << "WONJOO: at aten_xla_bridge.cpp, XlaCreateOptTensorList"
<< std::endl;
std::vector<c10::optional<at::Tensor>> opt_aten_xla_tensors(tensors.size());
std::vector<at::Tensor> materialized_tensors;
std::vector<bool> to_translate(tensors.size());
@@ -380,7 +387,8 @@ at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {

std::vector<at::Tensor> AtenFromXlaTensors(
absl::Span<const XLATensorPtr> xla_tensors) {
std::cout << "WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensors" << std::endl;
std::cout << "WONJOO: at aten_xla_bridge.cpp, AtenFromXlaTensors"
<< std::endl;
std::vector<at::Tensor> tensors;
tensors.reserve(xla_tensors.size());
for (auto& tensor : xla_tensors) {
58 changes: 31 additions & 27 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -684,10 +684,10 @@ at::Tensor XLANativeFunctions::as_strided_scatter(
storage_offset);
}
auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
auto base_clone = XLATensor::clone(base_);
auto base_clone_slice = XLATensor::as_strided(
auto base_clone = tensor_methods::clone(base_);
auto base_clone_slice = tensor_methods::as_strided(
base_clone, xsize, xstride, XlaHelpers::I64Optional(storage_offset));
XLATensor::copy_(base_clone_slice, mutated_view_);
tensor_methods::copy_(base_clone_slice, mutated_view_);
return bridge::AtenFromXlaTensor(base_clone);
}

@@ -1041,9 +1041,10 @@ at::Tensor XLANativeFunctions::diagonal_scatter(const at::Tensor& base,
int64_t dim2) {
auto base_ = bridge::GetXlaTensor(base);
auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
auto base_clone = XLATensor::clone(base_);
auto base_clone_slice = XLATensor::diagonal(base_clone, offset, dim1, dim2);
XLATensor::copy_(base_clone_slice, mutated_view_);
auto base_clone = tensor_methods::clone(base_);
auto base_clone_slice =
tensor_methods::diagonal(base_clone, offset, dim1, dim2);
tensor_methods::copy_(base_clone_slice, mutated_view_);
return bridge::AtenFromXlaTensor(base_clone);
}

@@ -2167,7 +2168,7 @@ at::Tensor& XLANativeFunctions::normal_(
}

at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self,
at::IntArrayRef dims) {
at::IntArrayRef dims) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::permute(
bridge::GetXlaTensor(self), XlaHelpers::I64List(dims)));
@@ -2544,10 +2545,10 @@ at::Tensor XLANativeFunctions::scatter_add(const at::Tensor& self, int64_t dim,
}

at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim,
int64_t index) {
int64_t index) {
TORCH_LAZY_FN_COUNTER("xla::");
std::cout << "WONJOO: at XLANativeFunctions::select_copy1" << std::endl;
XLA_FN_COUNTER("xla::");
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
tensor_methods::select(bridge::GetXlaTensor(self), dim, index));
}
@@ -2564,9 +2565,9 @@ at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base,
<< std::endl;
auto base_ = bridge::GetXlaTensor(base);
auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
auto base_clone = XLATensor::clone(base_);
auto base_clone_slice = XLATensor::select(base_clone, dim, index);
XLATensor::copy_(base_clone_slice, mutated_view_);
auto base_clone = tensor_methods::clone(base_);
auto base_clone_slice = tensor_methods::select(base_clone, dim, index);
tensor_methods::copy_(base_clone_slice, mutated_view_);
return bridge::AtenFromXlaTensor(base_clone);
}

@@ -2592,8 +2593,9 @@ at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
}

at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end, int64_t step) {
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
TORCH_LAZY_FN_COUNTER("xla::");
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
@@ -2606,12 +2608,12 @@ at::Tensor XLANativeFunctions::slice_scatter(
c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
auto base_ = bridge::GetXlaTensor(base);
auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
auto base_clone = XLATensor::clone(base_);
auto base_clone = tensor_methods::clone(base_);
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
auto base_clone_slice =
XLATensor::slice(base_clone, dim, start_val, end_val, step);
XLATensor::copy_(base_clone_slice, mutated_view_);
tensor_methods::slice(base_clone, dim, start_val, end_val, step);
tensor_methods::copy_(base_clone_slice, mutated_view_);
return bridge::AtenFromXlaTensor(base_clone);
}

@@ -2683,8 +2685,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::sort(
}

std::vector<at::Tensor> XLANativeFunctions::split_copy(const at::Tensor& self,
int64_t split_size,
int64_t dim) {
int64_t split_size,
int64_t dim) {
TORCH_LAZY_FN_COUNTER("xla::");
auto xla_tensors =
tensor_methods::split(bridge::GetXlaTensor(self), split_size, dim);
@@ -2711,7 +2713,8 @@ at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) {
tensor_methods::squeeze(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, int64_t dim) {
at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self,
int64_t dim) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
tensor_methods::squeeze(bridge::GetXlaTensor(self), dim));
@@ -2882,8 +2885,8 @@ at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
tensor_methods::trace(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, int64_t dim0,
int64_t dim1) {
at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self,
int64_t dim0, int64_t dim1) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
tensor_methods::transpose(bridge::GetXlaTensor(self), dim0, dim1));
@@ -2903,7 +2906,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::triangular_solve(
}

std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,
int64_t dim) {
int64_t dim) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensors(
tensor_methods::unbind(bridge::GetXlaTensor(self), dim));
@@ -2923,7 +2926,8 @@ at::Tensor& XLANativeFunctions::uniform_(
return self;
}

at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, int64_t dim) {
at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self,
int64_t dim) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
tensor_methods::unsqueeze(bridge::GetXlaTensor(self), dim));
@@ -3139,9 +3143,9 @@ at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self,
pixel_unshuffle)>::call(self, downscale_factor);
}

at::Tensor XLANativeFunctions::select_backward_symint(const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t dim, int64_t index) {
at::Tensor XLANativeFunctions::select_backward_symint(
const at::Tensor& grad_output, c10::SymIntArrayRef input_sizes, int64_t dim,
c10::SymInt index) {
std::cout << "WONJOO: at XLANativeFunctions::select_backward" << std::endl;
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
select_backward)>::call(grad_output, input_sizes, dim, index);
