Skip to content

Commit

Permalink
port embedding bag check for 2.1.40 (#4504)(#4482)
Browse files Browse the repository at this point in the history
* add check before embedding bag

* fix for clang-format

* add headers
  • Loading branch information
leizhenyuan authored Jul 17, 2024
1 parent 8b74d6c commit 5717479
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions csrc/gpu/aten/operators/EmbeddingBag.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <ATen/ATen.h>
#include <torch/torch.h>

#include <c10/util/Exception.h>
#include <core/Device.h>
#include <core/Memory.h>
#include <runtime/Utils.h>
#include <torch/torch.h>
#include <utils/DPCPP.h>

#include "BitonicMergeSort.h"
Expand Down Expand Up @@ -1182,6 +1182,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag(
const c10::optional<at::Tensor>& per_sample_weights_opt,
bool include_last_offset,
int64_t padding_idx) {
TORCH_CHECK(
indices.dim() == 1 || indices.dim() == 2,
"input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
indices.dim());
if (indices.dim() == 1) {
TORCH_CHECK(
offsets.dim() == 1,
"offsets has to be a 1D Tensor, but got Tensor of dimension ",
offsets.dim());
}
TORCH_CHECK(
weight.dim() == 2,
"weight has to be a 2D Tensor, but got Tensor of dimension ",
weight.dim());

c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned =
at::borrow_from_optional_tensor(per_sample_weights_opt);
const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
Expand Down Expand Up @@ -1234,6 +1249,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only(
const c10::optional<Tensor>& per_sample_weights_opt,
bool include_last_offset,
int64_t padding_idx) {
TORCH_CHECK(
indices.dim() == 1 || indices.dim() == 2,
"input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
indices.dim());
if (indices.dim() == 1) {
TORCH_CHECK(
offsets.dim() == 1,
"offsets has to be a 1D Tensor, but got Tensor of dimension ",
offsets.dim());
}
TORCH_CHECK(
weight.dim() == 2,
"weight has to be a 2D Tensor, but got Tensor of dimension ",
weight.dim());
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned =
at::borrow_from_optional_tensor(per_sample_weights_opt);
Expand Down

0 comments on commit 5717479

Please sign in to comment.