Skip to content

Commit

Permalink
[AutoParallel] slice support parallel part1 (#58837)
Browse files Browse the repository at this point in the history
* slice support parallel part1
  • Loading branch information
wanghuancoder authored Nov 9, 2023
1 parent 383ffe3 commit 980e284
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 10 additions & 6 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1319,8 +1319,8 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
"tensor %s has not been initialized, we can only slice initialized "
"tensor please init it first with numpy or other tensor.",
self->tensor.name()));
auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
ParseIndexingSlice(tensor,

ParseIndexingSlice(self->tensor.dims(),
_index,
&slice_axes,
&slice_starts,
Expand Down Expand Up @@ -1388,7 +1388,7 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
// NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
// with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
// otherwise the output shape will be not correct.
if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
if (static_cast<int>(decrease_axis.size()) == self->tensor.dims().size()) {
VLOG(1)
<< "Warning: In Tensor '__getitem__', if the number of scalar "
"elements "
Expand Down Expand Up @@ -1570,8 +1570,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
EAGER_TRY
VLOG(4) << "Call __setitem_eager_tensor";

auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

PyObject* _index = PyTuple_GET_ITEM(args, 0);
PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
// NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
Expand Down Expand Up @@ -1609,7 +1607,7 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
std::vector<int64_t> list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
ParseIndexingSlice(self_tensor,
ParseIndexingSlice(self->tensor.dims(),
index_ptr,
&axes,
&starts,
Expand Down Expand Up @@ -1775,6 +1773,12 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
}
}
} else {
PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
true,
platform::errors::InvalidArgument(
"This setitem mode only support DenseTensor."));
auto self_tensor =
static_cast<phi::DenseTensor*>(self->tensor.impl().get());
auto self_numpy = TensorToPyArray(*self_tensor, true);
VLOG(4) << "parse_index is false";
if (PyCheckTensor(_index)) {
Expand Down
8 changes: 2 additions & 6 deletions paddle/fluid/pybind/slice_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ static int _PySlice_GetIndices(PySliceObject* r,
return 0;
}

static void ParseIndexingSlice(phi::DenseTensor* tensor,
static void ParseIndexingSlice(phi::DDim shape,
PyObject* _index,
std::vector<int>* slice_axes,
std::vector<int>* slice_starts,
Expand All @@ -164,11 +164,7 @@ static void ParseIndexingSlice(phi::DenseTensor* tensor,
VLOG(4) << "Call Py_DECREF";
}
});
PADDLE_ENFORCE_EQ(
tensor->IsInitialized(),
true,
platform::errors::InvalidArgument("tensor has not been initialized"));
const auto& shape = tensor->dims();

const int rank = shape.size();
const int size = PyTuple_GET_SIZE(index);

Expand Down

0 comments on commit 980e284

Please sign in to comment.