Skip to content

Commit

Permalink
Merge pull request #8135 from JiayiFeng/dev_make_VarDesc_supporting_m…
Browse files Browse the repository at this point in the history
…ultiple_tensor

Add type `Reader` for `VarDesc`
  • Loading branch information
JiayiFeng authored Feb 6, 2018
2 parents 445c74c + e5227c2 commit c8ba6d5
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 24 deletions.
4 changes: 2 additions & 2 deletions paddle/framework/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward(
auto root_block = program_desc.MutableBlock(root_block_idx);

std::string fill_one_op_out = GradVarName(target.Name());
bool is_scalar = target.Shape() == std::vector<int64_t>{1};
bool is_scalar = target.GetShape() == std::vector<int64_t>{1};
PADDLE_ENFORCE(is_scalar, "target should be scalar");
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
Expand Down Expand Up @@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward(

auto var = root_block->Var(fill_one_op_out);
var->SetDataType(target.GetDataType());
var->SetShape(target.Shape());
var->SetShape(target.GetShape());
auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out;
target_grad.block_idx_ = root_block_idx;
Expand Down
10 changes: 7 additions & 3 deletions paddle/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ message LoDTensorArrayDesc {
optional int32 lod_level = 2 [ default = 0 ];
}

message Reader { repeated LoDTensorDesc lod_tensor = 1; }

message VarDesc {
enum VarType {
LOD_TENSOR = 1;
Expand All @@ -126,13 +128,15 @@ message VarDesc {
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
}
required string name = 1;
required VarType type = 2;
optional LoDTensorDesc lod_tensor = 3;
optional TensorDesc selected_rows = 4;
optional bool persistable = 3 [ default = false ];
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional bool persistable = 5 [ default = false ];
optional Reader reader = 7;
}

message BlockDesc {
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
try {
auto shape = var->Shape();
auto shape = var->GetShape();
if (shape.empty()) {
return framework::make_ddim({0UL});
} else {
return framework::make_ddim(var->Shape());
return framework::make_ddim(var->GetShape());
}
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/program_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_NE(copy, var_before);
ASSERT_EQ(copy->Name(), var_before->Name());
ASSERT_EQ(copy->GetType(), var_before->GetType());
ASSERT_EQ(copy->Shape(), var_before->Shape());
ASSERT_EQ(copy->GetShape(), var_before->GetShape());
ASSERT_EQ(copy->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
Expand Down Expand Up @@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ASSERT_NE(restored, var_before);
ASSERT_EQ(restored->Name(), var_before->Name());
ASSERT_EQ(restored->GetType(), var_before->GetType());
ASSERT_EQ(restored->Shape(), var_before->Shape());
ASSERT_EQ(restored->GetShape(), var_before->GetShape());
ASSERT_EQ(restored->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
Expand Down
174 changes: 163 additions & 11 deletions paddle/framework/var_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}

void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_.type()) {
case proto::VarDesc::READER: {
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add();
}
return;
} break;
default:
PADDLE_THROW(
"Setting 'sub_tensor_number' is not supported by the type of var %s.",
this->Name());
}
}

size_t VarDesc::GetTensorDescNum() const {
switch (desc_.type()) {
case proto::VarDesc::READER:
return desc_.reader().lod_tensor_size();
break;
default:
PADDLE_THROW(
"Getting 'sub_tensor_number' is not supported by the type of var %s.",
this->Name());
}
}

void VarDesc::SetShapes(
const std::vector<std::vector<int64_t>> &multiple_dims) {
PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
"The number of given shapes(%d) doesn't equal to the "
"number of sub tensor.",
multiple_dims.size(), GetTensorDescNum());
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
}
}

std::vector<int64_t> VarDesc::GetShape() const {
return RepeatedToVector(tensor_desc().dims());
}

std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(RepeatedToVector(tensor_desc.dims()));
}
return res;
}

void VarDesc::SetDataType(proto::DataType data_type) {
mutable_tensor_desc()->set_data_type(data_type);
}

std::vector<int64_t> VarDesc::Shape() const {
return RepeatedToVector(tensor_desc().dims());
void VarDesc::SetDataTypes(
const std::vector<proto::DataType> &multiple_data_type) {
PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor.",
multiple_data_type.size(), GetTensorDescNum());
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
}
}

proto::DataType VarDesc::GetDataType() const {
return tensor_desc().data_type();
}

std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
}
return res;
}

void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR:
Expand All @@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
desc_.mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
PADDLE_THROW(
"Setting 'lod_level' is not supported by the type of var %s.",
this->Name());
}
}

void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor.",
multiple_lod_level.size(), GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER: {
size_t i = 0;
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]);
}
} break;
default:
PADDLE_THROW(
"Setting 'lod_levels' is not supported by the type of var %s.",
this->Name());
}
}

Expand All @@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level();
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
PADDLE_THROW(
"Getting 'lod_level' is not supported by the type of var %s.",
this->Name());
}
}

std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res;
switch (desc_.type()) {
case proto::VarDesc::READER:
res.reserve(desc_.reader().lod_tensor_size());
for (auto &lod_tensor : desc_.reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level());
}
return res;
break;
default:
PADDLE_THROW(
"Getting 'lod_levels' is not supported by the type of var %s.",
this->Name());
}
}

const proto::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
Expand All @@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
default:
PADDLE_THROW("The type of var %s is unsupported.", this->Name());
PADDLE_THROW(
"Getting 'tensor_desc' is not supported by the type of var %s.",
this->Name());
}
}

std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
res.push_back(lod_tensor.tensor());
}
return res;
default:
PADDLE_THROW(
"Getting 'tensor_descs' is not supported by the type of var "
"%s.",
this->Name());
}
}

proto::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type");
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
Expand All @@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.mutable_tensor_array()->mutable_tensor();
default:
PADDLE_THROW("Unexpected branch.");
PADDLE_THROW(
"Getting 'mutable_tensor_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}

std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc *> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor());
}
return res;
default:
PADDLE_THROW(
"Getting 'tensor_descs' is not supported by the type of var "
"%s.",
this->Name());
}
}

} // namespace framework
} // namespace paddle
20 changes: 19 additions & 1 deletion paddle/framework/var_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,34 @@ class VarDesc {

void SetName(std::string name) { desc_.set_name(name); }

void SetTensorDescNum(size_t num);

size_t GetTensorDescNum() const;

void SetShape(const std::vector<int64_t> &dims);

void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);

std::vector<int64_t> GetShape() const;

std::vector<std::vector<int64_t>> GetShapes() const;

void SetDataType(proto::DataType data_type);

std::vector<int64_t> Shape() const;
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);

proto::DataType GetDataType() const;

std::vector<proto::DataType> GetDataTypes() const;

void SetLoDLevel(int32_t lod_level);

void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);

int32_t GetLoDLevel() const;

std::vector<int32_t> GetLoDLevels() const;

proto::VarDesc::VarType GetType() const;

void SetType(proto::VarDesc::VarType type);
Expand All @@ -90,7 +106,9 @@ class VarDesc {

private:
const proto::TensorDesc &tensor_desc() const;
std::vector<proto::TensorDesc> tensor_descs() const;
proto::TensorDesc *mutable_tensor_desc();
std::vector<proto::TensorDesc *> mutable_tensor_descs();

proto::VarDesc desc_;
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/inference/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor,
VLOG(3) << "parameter's name: " << var->Name();

framework::VarDesc* new_var = load_block->Var(var->Name());
new_var->SetShape(var->Shape());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
Expand Down
14 changes: 12 additions & 2 deletions paddle/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) {
py::return_value_policy::reference)
.def("set_name", &VarDesc::SetName)
.def("set_shape", &VarDesc::SetShape)
.def("set_shapes", &VarDesc::SetShapes)
.def("set_dtype", &VarDesc::SetDataType)
.def("shape", &VarDesc::Shape, py::return_value_policy::reference)
.def("set_dtypes", &VarDesc::SetDataTypes)
.def("set_tensor_num", &VarDesc::SetTensorDescNum)
.def("tensor_num", &VarDesc::GetTensorDescNum)
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
.def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference)
.def("lod_level", &VarDesc::GetLoDLevel)
.def("lod_levels", &VarDesc::GetLoDLevels,
py::return_value_policy::reference)
.def("set_lod_level", &VarDesc::SetLoDLevel)
.def("set_lod_levels", &VarDesc::SetLoDLevels)
.def("type", &VarDesc::GetType)
.def("set_type", &VarDesc::SetType)
.def("serialize_to_string", SerializeMessage<VarDesc>)
Expand All @@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) {
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
.value("READER", proto::VarDesc::READER);
}

void BindOpDesc(py::module &m) {
Expand Down
Loading

0 comments on commit c8ba6d5

Please sign in to comment.