Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples] fix multiprocessing on Linux #39

Merged
merged 2 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ If you want to maximize performance, you need to install with following steps ex

## `screen/`

Take a screen capture and process it. **This script only works on Windows.**
Take a screen capture and process it.

When you run the script, a translucent window appears. Position it at where you want to capture the screen and press the enter key to finalize the capture area.

Expand Down
13 changes: 8 additions & 5 deletions examples/optimal-performance/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import time
import tkinter as tk
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, get_context
from typing import List, Literal

import fire
Expand Down Expand Up @@ -174,17 +174,20 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
queue = Queue()
fps_queue = Queue()
process1 = Process(
ctx = get_context('spawn')
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
target=image_generation_process,
args=(queue, fps_queue, prompt, model_id_or_path, batch_size, acceleration),
)
process1.start()

process2 = Process(target=receive_images, args=(queue, fps_queue))
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
process2.start()

process1.join()
process2.join()

if __name__ == "__main__":
fire.Fire(main)
13 changes: 8 additions & 5 deletions examples/optimal-performance/single.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys
import time
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, get_context
from typing import Literal

import fire
Expand Down Expand Up @@ -72,17 +72,20 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
queue = Queue()
fps_queue = Queue()
process1 = Process(
ctx = get_context('spawn')
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
target=image_generation_process,
args=(queue, fps_queue, prompt, model_id_or_path, acceleration),
)
process1.start()

process2 = Process(target=receive_images, args=(queue, fps_queue))
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
process2.start()

process1.join()
process2.join()

if __name__ == "__main__":
fire.Fire(main)
14 changes: 8 additions & 6 deletions examples/screen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import time
import threading
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, get_context
from typing import List, Literal, Dict, Optional
import torch
import PIL.Image
Expand Down Expand Up @@ -216,10 +216,10 @@ def main(
Main function to start the image generation and viewer processes.
"""
monitor = dummy_screen(width, height)

queue = Queue()
fps_queue = Queue()
process1 = Process(
ctx = get_context('spawn')
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
target=image_generation_process,
args=(
queue,
Expand All @@ -246,9 +246,11 @@ def main(
)
process1.start()

process2 = Process(target=receive_images, args=(queue, fps_queue))
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
process2.start()

process1.join()
process2.join()

if __name__ == "__main__":
fire.Fire(main)