ONNX Embedding Model Thread-Safety Issue #23555
Comments
I've done a bunch of multithreaded testing and multithreaded embedding runs without Spring AI, so I'm a bit surprised that they are managing to trigger this kind of error. It shouldn't be possible to crash the JVM using the ORT Java API, so something's definitely busted in the ORT Java API; I'm just not sure what.
|
Hi @Craigacp, I will do my best to provide you with all of the information in your request below.
|
I've made this small harness, which should reproduce the error, but it doesn't (at least not on my M4 Pro Mac). Could you run it in your Windows environment and see if it crashes?
```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import java.util.stream.LongStream;

public class MultithreadingTest {
    private static final Logger logger = Logger.getLogger(MultithreadingTest.class.getName());
    // Single process-wide environment shared by every session and tensor.
    private static OrtEnvironment env = OrtEnvironment.getEnvironment();

    OrtSession.SessionOptions makeOpts() {
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        return opts;
    }

    OrtSession makeSession(OrtSession.SessionOptions opts) throws OrtException {
        // Loads the model file from the root of the classpath.
        Path path = new File(this.getClass().getResource("/all-minilm-l6-v2.onnx").getFile()).toPath();
        String modelPath = path.toString();
        OrtSession session = env.createSession(modelPath, opts);
        return session;
    }

    // Worker threads are daemons, so they don't keep the JVM alive.
    static ThreadFactory createThreadFactory() {
        return (Runnable runnable) -> {
            Thread thread = new Thread(runnable);
            thread.setDaemon(true);
            return thread;
        };
    }

    // 2 core / 4 max threads, a bounded queue of 25, and AbortPolicy for overflow.
    public ThreadPoolExecutor createThreadPoolExecutor() {
        RejectedExecutionHandler rejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy();
        BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(25);
        ThreadPoolExecutor executor = new ThreadPoolExecutor(
                2, 4, 60, TimeUnit.SECONDS,
                queue, createThreadFactory(), rejectedExecutionHandler);
        return executor;
    }

    //@Test
    public void arrayTest() throws OrtException, ExecutionException, InterruptedException {
        // One batch of 500 token ids with a full attention mask.
        long[][] ids = new long[][]{LongStream.range(100, 600).toArray()};
        long[][] mask = new long[1][500];
        Arrays.fill(mask[0], 1);
        ThreadPoolExecutor executor = createThreadPoolExecutor();
        OrtSession.SessionOptions opts = makeOpts();
        try (OrtSession session = makeSession(opts)) {
            opts.close();
            // Each task creates its own input tensors but shares the one session.
            Runnable r = () -> {
                try (OnnxTensor inputIds = OnnxTensor.createTensor(env, ids);
                     OnnxTensor attentionMask = OnnxTensor.createTensor(env, mask)) {
                    Map<String, OnnxTensor> input = new HashMap<>();
                    input.put("input_ids", inputIds);
                    input.put("attention_mask", attentionMask);
                    try (OrtSession.Result result = session.run(input)) {
                        float[][][] output = (float[][][]) result.get(0).getValue();
                        logger.info("Output is ["+output.length+"]["+output[0].length+"]["+output[0][0].length+"]");
                    }
                } catch (OrtException e) {
                    throw new RuntimeException(e);
                }
            };
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < 25; i++) {
                futures.add(executor.submit(r));
            }
            // Wait for every task before the session is closed.
            for (Future<?> f : futures) {
                f.get();
            }
        }
        logger.info("Submitted tasks");
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.MINUTES);
        logger.info("Shutdown executor");
    }

    public static void main(String[] args) throws OrtException, ExecutionException, InterruptedException {
        MultithreadingTest t = new MultithreadingTest();
        t.arrayTest();
    }
}
```
I've not replicated the async behaviour, as I've never used Spring or its async stuff. If you can tell me how to modify the thread pool/submit, then maybe it'll trigger it? |
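Spring's `@Async` hands a call to an executor and returns to the caller straight away, whereas the harness above waits on every `Future` before closing the session, so it never opens the window in which the owner tears down a shared resource while tasks are still pending. The stdlib-only sketch below is an assumption about what mimicking that would take, not Spring's or ORT's code: `FakeModel` is a hypothetical stand-in for an `OrtSession`, and two latches force the bad interleaving (close first, then let the tasks run) so the outcome is deterministic instead of intermittent.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

public class FireAndForgetSketch {
    // Hypothetical stand-in for a shared native resource such as an OrtSession.
    static final class FakeModel implements AutoCloseable {
        private final AtomicBoolean closed = new AtomicBoolean(false);

        boolean isClosed() {
            return closed.get();
        }

        @Override
        public void close() {
            closed.set(true);
        }
    }

    /** Returns how many submitted tasks found the model already closed. */
    static int runFireAndForget(int tasks) throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(2, r -> {
            Thread t = new Thread(r);
            t.setDaemon(true); // daemon workers, like the harness's thread factory
            return t;
        });
        CountDownLatch release = new CountDownLatch(1);
        CountDownLatch finished = new CountDownLatch(tasks);
        AtomicInteger sawClosed = new AtomicInteger();
        FakeModel model = new FakeModel();
        for (int i = 0; i < tasks; i++) {
            // Fire-and-forget: the returned Future is ignored, so nobody waits on it.
            pool.submit(() -> {
                try {
                    release.await(); // simulates a slow embedding call still in flight
                    if (model.isClosed()) {
                        sawClosed.incrementAndGet();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    finished.countDown();
                }
            });
        }
        model.close();        // the owner (e.g. a test's context) tears down first
        release.countDown();  // only now do the pending tasks touch the model
        finished.await(10, TimeUnit.SECONDS); // waiting here only so the sketch can report a count
        pool.shutdown();
        return sawClosed.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runFireAndForget(5) + " tasks saw a closed model");
    }
}
```

Running `main` prints `5 tasks saw a closed model`. With a real native session, that late access is the window in which the errors and crashes described in this thread can appear.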
Got it!
I used your code without changing anything. |
Hmm, your version of …
```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import java.util.stream.LongStream;

public class MultithreadingTest {
    private static final Logger logger = Logger.getLogger(MultithreadingTest.class.getName());
    private static OrtEnvironment env = OrtEnvironment.getEnvironment();

    OrtSession.SessionOptions makeOpts() {
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        return opts;
    }

    OrtSession makeSession(OrtSession.SessionOptions opts) throws OrtException {
        Path path = new File(this.getClass().getResource("/all-minilm-l6-v2.onnx").getFile()).toPath();
        String modelPath = path.toString();
        OrtSession session = env.createSession(modelPath, opts);
        return session;
    }

    static ThreadFactory createThreadFactory() {
        return (Runnable runnable) -> {
            Thread thread = new Thread(runnable);
            thread.setDaemon(true);
            return thread;
        };
    }

    public ThreadPoolExecutor createThreadPoolExecutor() {
        RejectedExecutionHandler rejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy();
        BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(25);
        ThreadPoolExecutor executor = new ThreadPoolExecutor(
                2, 4, 60, TimeUnit.SECONDS,
                queue, createThreadFactory(), rejectedExecutionHandler);
        return executor;
    }

    //@Test
    public void arrayTest() throws OrtException, ExecutionException, InterruptedException {
        long[][] ids = new long[][]{LongStream.range(100, 600).toArray()};
        long[][] mask = new long[1][500];
        Arrays.fill(mask[0], 1);
        // Unlike the first harness, this one also feeds token_type_ids (all zeros).
        long[][] type = new long[1][500];
        Arrays.fill(type[0], 0);
        ThreadPoolExecutor executor = createThreadPoolExecutor();
        OrtSession.SessionOptions opts = makeOpts();
        try (OrtSession session = makeSession(opts)) {
            opts.close();
            Runnable r = () -> {
                try (OnnxTensor inputIds = OnnxTensor.createTensor(env, ids);
                     OnnxTensor attentionMask = OnnxTensor.createTensor(env, mask);
                     OnnxTensor tokenType = OnnxTensor.createTensor(env, type)) {
                    Map<String, OnnxTensor> input = new HashMap<>();
                    input.put("input_ids", inputIds);
                    input.put("attention_mask", attentionMask);
                    input.put("token_type_ids", tokenType);
                    try (OrtSession.Result result = session.run(input)) {
                        float[][][] output = (float[][][]) result.get(0).getValue();
                        logger.info("Output is ["+output.length+"]["+output[0].length+"]["+output[0][0].length+"]");
                    }
                } catch (OrtException e) {
                    throw new RuntimeException(e);
                }
            };
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < 25; i++) {
                futures.add(executor.submit(r));
            }
            for (Future<?> f : futures) {
                f.get();
            }
        }
        logger.info("Submitted tasks");
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.MINUTES);
        logger.info("Shutdown executor");
    }

    public static void main(String[] args) throws OrtException, ExecutionException, InterruptedException {
        MultithreadingTest t = new MultithreadingTest();
        t.arrayTest();
    }
}
```
|
This code seems to work fine. I can provide you with a small Spring project in which I replicate the problem if you wish. I have configured the thread pool with the Spring classes, but I assume you have configured it the same way. |
Ok, so maybe the JVM crash has something to do with the error-handling pathway when it's overloaded. I'm still confused by the logger error you get when running async. If you have a demo that drives it in Spring with the async task scheduler, that would be helpful; I can try to replicate it in a Windows environment and then get the debugger on it. |
Also, could you test it with the latest release (1.20.1)? |
There are two issues as I see it.
And the second one: why does throwing an exception in the multithreaded test produce an Access Violation? |
That runtime error comes from @alfredogangemi's model being a little different from the one I was testing, but this is confusing to me:
I don't know how we can end up without the default logger registered; the Java code shouldn't let you do anything without an |
Hi @Craigacp, |
The test project replicates the default logger exception for me when running in Maven, and the crash when running in IntelliJ on my Windows box. I'll look into it. My suspicion is that the JVM crash is a consequence of the exceptions and something is weird in the exception handling path, but I've not run it down yet. |
There's a lot going on here, but one major issue is that Spring seems to throw away the embedding call without anything holding a strong reference to it, so something gets closed out from under ONNX Runtime when running async. Adding a timeout to the async executor and making it wait causes the test to complete without error:
I think that's a problem with Spring AI and how it works with async stuff, as the embedding model shouldn't be closed until all the tasks have finished. |
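The fix described above (make the async executor wait before the model is torn down) boils down to ordering shutdown: stop accepting work, wait for in-flight tasks, and only then close the shared model. A minimal stdlib-only sketch of that ordering, assuming a hypothetical `FakeModel` in place of the real embedding model. In Spring, the analogous settings on `ThreadPoolTaskExecutor` are `setWaitForTasksToCompleteOnShutdown(true)` and `setAwaitTerminationSeconds(...)`, though how they interact with a given test's context shutdown is worth checking against the Spring documentation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

public class OrderedShutdownSketch {
    // Hypothetical stand-in for the embedding model (not a Spring AI or ORT class).
    static final class FakeModel implements AutoCloseable {
        private final AtomicBoolean closed = new AtomicBoolean(false);

        float[] embed(String text) {
            if (closed.get()) {
                throw new IllegalStateException("model already closed");
            }
            return new float[] {text.length()}; // placeholder "embedding"
        }

        @Override
        public void close() {
            closed.set(true);
        }
    }

    /** Embeds every text on a pool and closes the model only after the pool has drained. */
    static int embedAll(List<String> texts) throws InterruptedException, ExecutionException {
        FakeModel model = new FakeModel();
        ExecutorService pool = Executors.newFixedThreadPool(4);
        AtomicInteger embedded = new AtomicInteger();
        List<Future<?>> futures = new ArrayList<>();
        for (String text : texts) {
            futures.add(pool.submit(() -> {
                model.embed(text);
                embedded.incrementAndGet();
            }));
        }
        pool.shutdown(); // stop accepting new tasks; already-submitted ones keep running
        if (!pool.awaitTermination(1, TimeUnit.MINUTES)) {
            // Timed out: leave the model open rather than close it under a running task.
            throw new IllegalStateException("embedding tasks did not finish in time");
        }
        try {
            for (Future<?> f : futures) {
                f.get(); // rethrows any failure from inside a task
            }
        } finally {
            model.close(); // safe: the pool has terminated, so no task can still use the model
        }
        return embedded.get();
    }

    public static void main(String[] args) throws Exception {
        System.out.println(embedAll(List.of("a", "bb", "ccc")) + " texts embedded");
    }
}
```

The design choice is to treat a timeout as a failure that leaves the model open, rather than closing it underneath a task that is still running.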
Hi @Craigacp, first of all, a huge thank you for taking the time to investigate this issue! I really appreciate your help. Now that I understand the root cause, it makes perfect sense. The issue is related to how Spring Boot handles unit tests: when a test finishes, Spring shuts down the application context, but the async executor threads are still running in the background. Since the embedding model (and possibly other beans) is managed by Spring, it gets destroyed before the async tasks can complete, which leads to issues when those threads try to access it. Thanks again for your support! |
The reproducer does crash for me on macOS, and while that's due to the async behaviour not waiting properly, I'd really prefer the JVM not to crash when the user uses ORT wrongly, so I'm still trying to figure out exactly what's causing the crash. Maybe stdout has gone away? @yuslepukhin, this is the stack trace I get out of a debug build on macOS; does it make more sense to you? It bottoms out in the logger's ISink::Send method, but I don't know what could be null in there. The crash is intermittent too; sometimes we get an
|
The problem is not necessarily caused by a nullptr; that would be easy to detect. The intermittent nature of the crash points to a race condition. Without reference to this specific issue, in general, this may be caused by:
The Ort::Environment object must be global to the process and must not be auto-closed prematurely (how's my Java terminology? :) |
That all sounds fine. It turns out I'd left myself a note in the commit that changed the behaviour of OrtEnvironment (#10670) describing exactly how this bug can occur: Java daemon threads run concurrently with shutdown hooks, so those threads can see a partially shut-down OrtEnv, and this can't be prevented from Java. I'd completely forgotten about that interaction between shutdown hooks and daemon threads. |
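Per the comment above, the race on the native OrtEnv itself can't be prevented from Java, because daemon threads keep running while shutdown hooks execute. For resources that your own code owns and closes, though, a common mitigation is to make `close()` wait for in-flight calls and make any later call fail with a Java exception instead of touching freed state. A minimal sketch, assuming a hypothetical `GuardedResource` wrapper (not part of the ORT Java API) built on `ReentrantReadWriteLock`:

```java
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;

public class GuardedResourceSketch {
    /** Hypothetical wrapper around a closeable resource; not an ORT class. */
    static final class GuardedResource implements AutoCloseable {
        private final ReadWriteLock lock = new ReentrantReadWriteLock();
        private boolean closed = false; // only read or written while holding the lock

        <T> T use(Supplier<T> action) {
            lock.readLock().lock(); // many callers may use the resource concurrently
            try {
                if (closed) {
                    throw new IllegalStateException("resource already closed");
                }
                return action.get();
            } finally {
                lock.readLock().unlock();
            }
        }

        @Override
        public void close() {
            lock.writeLock().lock(); // blocks until no call holds the read lock
            try {
                closed = true; // a real wrapper would release native memory here
            } finally {
                lock.writeLock().unlock();
            }
        }
    }

    public static void main(String[] args) {
        GuardedResource r = new GuardedResource();
        System.out.println(r.use(() -> "ok"));
        r.close();
        try {
            r.use(() -> "too late");
        } catch (IllegalStateException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```

`main` prints `ok` and then `caught: resource already closed`. The trade-off is that `close()` blocks for as long as any call is in progress, so a long-running call delays shutdown; a real wrapper might prefer `writeLock().tryLock(timeout, unit)`.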
Bug description
When using the default ONNX embedding model in Spring AI (all-MiniLM-L6-v2), running the embedding process asynchronously with a ThreadPoolTaskExecutor results in inconsistent behavior and occasional runtime exceptions. The issue does not occur when executing the process synchronously.
Environment
Java Version: 17
Spring Boot Version: Latest
Spring AI Version: 1.0.0-M5
ONNX Model: all-MiniLM-L6-v2
Steps to reproduce
Handling embedding asynchronously.
Expected behavior
The embedding process should run correctly across multiple threads.
Observed behavior
Logs
Other info
The JVM crashed in some tests:
I understand that this issue is likely caused by the interaction between Spring AI and ONNX Runtime, so I will also open an issue on the Spring AI repository. However, in the meantime, could you provide more information regarding the error encountered?