diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index b9dad479b86..9c5aa7b3c07 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -85,10 +85,26 @@ protected NDList forwardInternal( Map container = new ConcurrentHashMap<>(); // forward try (OrtNDManager sub = (OrtNDManager) manager.newSubManager()) { - // feed data in to match names - for (int i = 0; i < inputNames.size(); ++i) { - OrtNDArray ortNDArray = sub.from(inputs.get(i)); - container.put(inputNames.get(i), ortNDArray.getTensor()); + // If input data has name + if (inputs.get(0).getName() != null) { + for (NDArray input : inputs) { + String name = input.getName(); + if (name == null) { + throw new IllegalArgumentException( + "All or none of input tensors must have a name."); + } + if (!inputNames.contains(name)) { + throw new IllegalArgumentException("Invalid input tensor name: " + name); + } + OrtNDArray ortNDArray = sub.from(input); + container.put(name, ortNDArray.getTensor()); + } + } else { + // feed data in to match names + for (int i = 0; i < inputNames.size(); ++i) { + OrtNDArray ortNDArray = sub.from(inputs.get(i)); + container.put(inputNames.get(i), ortNDArray.getTensor()); + } } OrtSession.Result results = session.run(container); @@ -100,6 +116,16 @@ protected NDList forwardInternal( } } + /** {@inheritDoc} */ + @Override + public PairList describeInput() { + PairList result = new PairList<>(); + for (String name : session.getInputNames()) { + result.add(name, null); + } + return result; + } + private NDList evaluateOutput(OrtSession.Result results) { NDList output = new NDList(); for (Map.Entry r : results) {