Skip to content

Commit

Permalink
fix param parsing issue when layer/blob name exceeds 255
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangGe6 committed Oct 4, 2022
1 parent 59a6fa3 commit bc920ea
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions tools/onnx/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2930,6 +2930,30 @@ static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<
}
}

// truncate layer/blob names when they exceed 255, which is the upper limit when parsing param in src/net.cpp
static std::string trunc_name(std::string name)
{
static int trunc_idx = 0;
static std::map<std::string, std::string> name_trunc_map;

const int max_len = 255;
if (name.size() <= max_len)
{
return name;
}
if (name_trunc_map.count(name))
{
return name_trunc_map[name];
}

std::string concat_name = name + "_t" + std::to_string(trunc_idx);
std::string trunc_name = concat_name.substr(concat_name.size() - max_len);
trunc_idx += 1;
name_trunc_map[name] = trunc_name;

return trunc_name;
}

int main(int argc, char** argv)
{
if (!(argc == 2 || argc == 4))
Expand Down Expand Up @@ -3433,7 +3457,7 @@ int main(int argc, char** argv)
if (weights.find(input_name) != weights.end())
continue;

fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", trunc_name(input_name).c_str(), trunc_name(input_name).c_str());

int refcount = node_reference[input_name];
if (refcount <= 1)
Expand All @@ -3444,11 +3468,12 @@ int main(int argc, char** argv)
char splitname[256];
sprintf(splitname, "splitncnn_input%d", j);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
fprintf(pp, " %s", input_name.c_str());
fprintf(pp, " %s", trunc_name(input_name).c_str());

for (int k = 0; k < refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
std::string split_name = input_name + "_splitncnn_" + std::to_string(k);
fprintf(pp, " %s", trunc_name(split_name).c_str());
}
fprintf(pp, "\n");
}
Expand All @@ -3464,7 +3489,7 @@ int main(int argc, char** argv)
continue;
}

fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", trunc_name(input_name).c_str(), trunc_name(input_name).c_str());

const onnx::TensorProto& M = weights[input_name];

Expand Down Expand Up @@ -3513,11 +3538,12 @@ int main(int argc, char** argv)
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

fprintf(pp, " %s", input_name.c_str());
fprintf(pp, " %s", trunc_name(input_name).c_str());

for (int k = 0; k < refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
std::string split_name = input_name + "_splitncnn_" + std::to_string(k);
fprintf(pp, " %s", trunc_name(split_name).c_str());
}
fprintf(pp, "\n");

Expand Down Expand Up @@ -3939,7 +3965,7 @@ int main(int argc, char** argv)
fprintf(pp, "%-16s", op.c_str());
}

fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
fprintf(pp, " %-24s %d %d", trunc_name(name).c_str(), input_size, output_size);

for (int j = 0; j < (int)node.input_size(); j++)
{
Expand All @@ -3966,14 +3992,14 @@ int main(int argc, char** argv)
input_name = input_name + splitsuffix;
}

fprintf(pp, " %s", input_name.c_str());
fprintf(pp, " %s", trunc_name(input_name).c_str());
}

for (int j = 0; j < output_size; j++)
{
const std::string& output_name = node.output(j);

fprintf(pp, " %s", output_name.c_str());
fprintf(pp, " %s", trunc_name(output_name).c_str());
}

if (op == "Abs")
Expand Down Expand Up @@ -6064,11 +6090,12 @@ int main(int argc, char** argv)
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

fprintf(pp, " %s", output_name.c_str());
fprintf(pp, " %s", trunc_name(output_name).c_str());

for (int k = 0; k < refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
std::string split_name = output_name + "_splitncnn_" + std::to_string(k);
fprintf(pp, " %s", trunc_name(split_name).c_str());
}
fprintf(pp, "\n");

Expand Down

0 comments on commit bc920ea

Please sign in to comment.