From a06fbd1e788f93381138c9cbe8b5dfe5af489ac3 Mon Sep 17 00:00:00 2001 From: GradientSurfer Date: Sat, 23 Dec 2023 22:35:45 -0600 Subject: [PATCH 1/2] [examples] fix multiprocessing on Linux - use multiprocessing context to specify the spawn start method, which fixes the "RuntimeError: Cannot re-initialize CUDA in forked subprocess" on Linux (verified with Ubuntu 22.04 and kernel 6.2.0-39) - call `.join()` to wait for processes to complete (avoids exiting program immediately) --- examples/optimal-performance/multi.py | 13 ++++++++----- examples/optimal-performance/single.py | 13 ++++++++----- examples/screen/main.py | 14 ++++++++------ 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py index 8d8ce691..c298d59f 100644 --- a/examples/optimal-performance/multi.py +++ b/examples/optimal-performance/multi.py @@ -3,7 +3,7 @@ import threading import time import tkinter as tk -from multiprocessing import Process, Queue +from multiprocessing import Process, Queue, get_context from typing import List, Literal import fire @@ -174,17 +174,20 @@ def main( """ Main function to start the image generation and viewer processes. """ - queue = Queue() - fps_queue = Queue() - process1 = Process( + ctx = get_context('spawn') + queue = ctx.Queue() + fps_queue = ctx.Queue() + process1 = ctx.Process( target=image_generation_process, args=(queue, fps_queue, prompt, model_id_or_path, batch_size, acceleration), ) process1.start() - process2 = Process(target=receive_images, args=(queue, fps_queue)) + process2 = ctx.Process(target=receive_images, args=(queue, fps_queue)) process2.start() + process1.join() + process2.join() if __name__ == "__main__": fire.Fire(main) diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py index 8a2b16b5..65594aa8 100644 --- a/examples/optimal-performance/single.py +++ b/examples/optimal-performance/single.py @@ -1,7 +1,7 @@ import os import sys import time -from multiprocessing import Process, Queue +from multiprocessing import Process, Queue, get_context from typing import Literal import fire @@ -72,17 +72,20 @@ def main( """ Main function to start the image generation and viewer processes. """ - queue = Queue() - fps_queue = Queue() - process1 = Process( + ctx = get_context('spawn') + queue = ctx.Queue() + fps_queue = ctx.Queue() + process1 = ctx.Process( target=image_generation_process, args=(queue, fps_queue, prompt, model_id_or_path, acceleration), ) process1.start() - process2 = Process(target=receive_images, args=(queue, fps_queue)) + process2 = ctx.Process(target=receive_images, args=(queue, fps_queue)) process2.start() + process1.join() + process2.join() if __name__ == "__main__": fire.Fire(main) diff --git a/examples/screen/main.py b/examples/screen/main.py index a7dde65c..9f3705bd 100644 --- a/examples/screen/main.py +++ b/examples/screen/main.py @@ -2,7 +2,7 @@ import sys import time import threading -from multiprocessing import Process, Queue +from multiprocessing import Process, Queue, get_context from typing import List, Literal, Dict, Optional import torch import PIL.Image @@ -216,10 +216,10 @@ def main( Main function to start the image generation and viewer processes. """ monitor = dummy_screen(width, height) - - queue = Queue() - fps_queue = Queue() - process1 = Process( + ctx = get_context('spawn') + queue = ctx.Queue() + fps_queue = ctx.Queue() + process1 = ctx.Process( target=image_generation_process, args=( queue, @@ -246,9 +246,11 @@ def main( ) process1.start() - process2 = Process(target=receive_images, args=(queue, fps_queue)) + process2 = ctx.Process(target=receive_images, args=(queue, fps_queue)) process2.start() + process1.join() + process2.join() if __name__ == "__main__": fire.Fire(main) \ No newline at end of file From cce112502913b45fc156b22cad8215762ac8a9d6 Mon Sep 17 00:00:00 2001 From: GradientSurfer Date: Sat, 23 Dec 2023 23:22:12 -0600 Subject: [PATCH 2/2] update docs --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index 11ee05db..7a3502d0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -10,7 +10,7 @@ If you want to maximize performance, you need to install with following steps ex ## `screen/` -Take a screen capture and process it. **This script only works on Windows.** +Take a screen capture and process it. When you run the script, a translucent window appears. Position it at where you want to capture the screen and press the enter key to finalize the capture area.