Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNXRuntime] fix naming #1608

Merged
merged 2 commits into from
Apr 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,26 @@ protected NDList forwardInternal(
Map<String, OnnxTensor> container = new ConcurrentHashMap<>();
// forward
try (OrtNDManager sub = (OrtNDManager) manager.newSubManager()) {
// feed data in to match names
for (int i = 0; i < inputNames.size(); ++i) {
OrtNDArray ortNDArray = sub.from(inputs.get(i));
container.put(inputNames.get(i), ortNDArray.getTensor());
// If input data has name
if (inputs.get(0).getName() != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not assume all input has name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User should either provide all names to the NDArray or provides no name

for (NDArray input : inputs) {
String name = input.getName();
if (name == null) {
throw new IllegalArgumentException(
"All or none of input tensors must have a name.");
}
if (!inputNames.contains(name)) {
throw new IllegalArgumentException("Invalid input tensor name: " + name);
}
OrtNDArray ortNDArray = sub.from(input);
container.put(name, ortNDArray.getTensor());
}
} else {
// feed data in to match names
for (int i = 0; i < inputNames.size(); ++i) {
OrtNDArray ortNDArray = sub.from(inputs.get(i));
container.put(inputNames.get(i), ortNDArray.getTensor());
}
}

OrtSession.Result results = session.run(container);
Expand All @@ -100,6 +116,16 @@ protected NDList forwardInternal(
}
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
PairList<String, Shape> result = new PairList<>();
for (String name : session.getInputNames()) {
result.add(name, null);
}
return result;
}

private NDList evaluateOutput(OrtSession.Result results) {
NDList output = new NDList();
for (Map.Entry<String, OnnxValue> r : results) {
Expand Down