Skip to content

Commit

Permalink
Handle combined case of 1.0 save format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhhsplendid committed Aug 3, 2023
1 parent 7db4ce8 commit ae92b26
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 15 deletions.
12 changes: 2 additions & 10 deletions paddle/cinn/frontend/paddle/model_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,8 @@ void LoadModelPb(const std::string &model_dir,
VLOG(3) << "param_file is: " << param_file;
// Load model
VLOG(4) << "Start load model program...";
std::string prog_path = model_dir + "/__model__";
std::string param_file_temp = param_file;
if (combined) {
// In combined case, prog path is saved as [model_dir].pdmodel
// Param file is saved as [model_dir].pdiparams
// For example, /path/model.pdmodel , /path/model.pdiparams
// In this case, model_file = ".pdmodel", param_file = ".pdiparams"
prog_path = model_dir + model_file;
param_file_temp = model_dir + param_file;
}
std::string prog_path = model_dir + model_file;
std::string param_file_temp = model_dir + param_file;
framework_proto::ProgramDesc pb_proto_prog =
*LoadProgram(prog_path, model_from_memory);
pb::ProgramDesc pb_prog(&pb_proto_prog);
Expand Down
5 changes: 2 additions & 3 deletions paddle/cinn/frontend/paddle_model_convertor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,14 @@ Program PaddleModelConvertor::LoadModel(
paddle::cpp::ProgramDesc program_desc;
if (FLAGS_cinn_infer_model_version < 2.0) {
paddle::LoadModelPb(model_dir,
"__model__",
"",
"/__model__",
"/params",
scope_.get(),
&program_desc,
is_combined,
false,
target_);
} else {
is_combined = true;
paddle::LoadModelPb(model_dir,
".pdmodel",
".pdiparams",
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/frontend/paddle_model_to_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ std::unique_ptr<Program> PaddleModelToProgram::operator()(
paddle::cpp::ProgramDesc program_desc;
if (FLAGS_cinn_infer_model_version < 2.0) {
paddle::LoadModelPb(model_dir,
"__model__",
"",
"/__model__",
"/params",
scope_,
&program_desc,
is_combined,
Expand Down

0 comments on commit ae92b26

Please sign in to comment.