diff --git a/.eslintrc.js b/.eslintrc.js index 9c70eff85..2e7258f6b 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -78,6 +78,8 @@ module.exports = { //extraNetworks.js requestGet: "readonly", popup: "readonly", + // profilerVisualization.js + createVisualizationTable: "readonly", // from python localization: "readonly", // progrssbar.js diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml deleted file mode 100644 index 5876e9410..000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ /dev/null @@ -1,105 +0,0 @@ -name: Bug Report -description: You think something is broken in the UI -title: "[Bug]: " -labels: ["bug-report"] - -body: - - type: markdown - attributes: - value: | - > The title of the bug report should be short and descriptive. - > Use relevant keywords for searchability. - > Do not leave it blank, but also do not put an entire error log in it. - - type: checkboxes - attributes: - label: Checklist - description: | - Please perform basic debugging to see if extensions or configuration is the cause of the issue. - Basic debug procedure -  1. Disable all third-party extensions - check if extension is the cause -  2. Update extensions and webui - sometimes things just need to be updated -  3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration -  4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed -  5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue - Before making a issue report please, check that the issue hasn't been reported recently. - options: - - label: The issue exists after disabling all extensions - - label: The issue exists on a clean installation of webui - - label: The issue is caused by an extension, but I believe it is caused by a bug in the webui - - label: The issue exists in the current version of the webui - - label: The issue has not been reported before recently - - label: The issue has been reported before but has not been fixed yet - - type: markdown - attributes: - value: | - > Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible - - type: textarea - id: what-did - attributes: - label: What happened? - description: Tell us what happened in a very clear and simple way - placeholder: | - txt2img is not working as intended. - validations: - required: true - - type: textarea - id: steps - attributes: - label: Steps to reproduce the problem - description: Please provide us with precise step by step instructions on how to reproduce the bug - placeholder: | - 1. Go to ... - 2. Press ... - 3. ... - validations: - required: true - - type: textarea - id: what-should - attributes: - label: What should have happened? - description: Tell us what you think the normal behavior should be - placeholder: | - WebUI should ... - validations: - required: true - - type: dropdown - id: browsers - attributes: - label: What browsers do you use to access the UI ? - multiple: true - options: - - Mozilla Firefox - - Google Chrome - - Brave - - Apple Safari - - Microsoft Edge - - Android - - iOS - - Other - - type: textarea - id: sysinfo - attributes: - label: Sysinfo - description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. 
If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file. - placeholder: | - 1. Go to WebUI Settings -> Sysinfo -> Download system info. - If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file - 2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text. - validations: - required: true - - type: textarea - id: logs - attributes: - label: Console logs - description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service. - render: Shell - validations: - required: true - - type: textarea - id: misc - attributes: - label: Additional information - description: | - Please provide us with any relevant additional info or context. - Examples: -  I have updated my GPU driver recently. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index f58c94a9b..000000000 --- a/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,5 +0,0 @@ -blank_issues_enabled: false -contact_links: - - name: WebUI Community Support - url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions - about: Please ask and answer questions here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml deleted file mode 100644 index 35a887408..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Feature request -description: Suggest an idea for this project -title: "[Feature Request]: " -labels: ["enhancement"] - -body: - - type: checkboxes - attributes: - label: Is there an existing issue for this? - description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. - options: - - label: I have searched the existing issues and checked the recent builds/commits - required: true - - type: markdown - attributes: - value: | - *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* - - type: textarea - id: feature - attributes: - label: What would your feature do ? - description: Tell us about your feature in a very clear and simple way, and what problem it would solve - validations: - required: true - - type: textarea - id: workflow - attributes: - label: Proposed workflow - description: Please provide us with step by step information on how you'd like the feature to be accessed and used - value: | - 1. Go to .... - 2. Press .... - 3. ... - validations: - required: true - - type: textarea - id: misc - attributes: - label: Additional information - description: Add any other context or screenshots about the feature request here. 
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md deleted file mode 100644 index c9fcda2e2..000000000 --- a/.github/pull_request_template.md +++ /dev/null @@ -1,15 +0,0 @@ -## Description - -* a simple description of what you're trying to accomplish -* a summary of changes in code -* which issues it fixes, if any - -## Screenshots/videos: - - -## Checklist: - -- [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) -- [ ] I have performed a self-review of my own code -- [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style) -- [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests) diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml deleted file mode 100644 index 9e44c806a..000000000 --- a/.github/workflows/on_pull_request.yaml +++ /dev/null @@ -1,38 +0,0 @@ -name: Linter - -on: - - push - - pull_request - -jobs: - lint-python: - name: ruff - runs-on: ubuntu-latest - if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name - steps: - - name: Checkout Code - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: 3.11 - # NB: there's no cache: pip here since we're not installing anything - # from the requirements.txt file(s) in the repository; it's faster - # not to have GHA download an (at the time of writing) 4 GB cache - # of PyTorch and other dependencies. - - name: Install Ruff - run: pip install ruff==0.1.6 - - name: Run Ruff - run: ruff . - lint-js: - name: eslint - runs-on: ubuntu-latest - if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name - steps: - - name: Checkout Code - uses: actions/checkout@v3 - - name: Install Node.js - uses: actions/setup-node@v3 - with: - node-version: 18 - - run: npm i --ci - - run: npm run lint diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml deleted file mode 100644 index e075ba60d..000000000 --- a/.github/workflows/run_tests.yaml +++ /dev/null @@ -1,107 +0,0 @@ -name: Tests - -on: - - push - - pull_request - -env: - FORGE_CQ_TEST: "True" - -jobs: - test: - name: tests on CPU - runs-on: ubuntu-latest - if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name - steps: - - name: Checkout Code - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: 3.10.6 - cache: pip - cache-dependency-path: | - **/requirements*txt - launch.py - - name: Cache models - id: cache-models - uses: actions/cache@v3 - with: - path: models - key: "2023-12-30" - - name: Install test dependencies - run: pip install wait-for-it -r requirements-test.txt - env: - PIP_DISABLE_PIP_VERSION_CHECK: "1" - PIP_PROGRESS_BAR: "off" - - name: Setup environment - run: python launch.py --skip-torch-cuda-test --exit - env: - PIP_DISABLE_PIP_VERSION_CHECK: "1" - PIP_PROGRESS_BAR: "off" - TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu - WEBUI_LAUNCH_LIVE_OUTPUT: "1" - PYTHONUNBUFFERED: "1" - - name: Print installed packages - run: pip freeze - - name: Download models - run: | - declare -a urls=( - 
"https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors" - ) - for url in "${urls[@]}"; do - filename="models/Stable-diffusion/${url##*/}" # Extracts the last part of the URL - if [ ! -f "$filename" ]; then - curl -Lo "$filename" "$url" - fi - done - # - name: Download ControlNet models - # run: | - # declare -a urls=( - # "https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth" - # ) - - # for url in "${urls[@]}"; do - # filename="models/ControlNet/${url##*/}" # Extracts the last part of the URL - # if [ ! -f "$filename" ]; then - # curl -Lo "$filename" "$url" - # fi - # done - - name: Start test server - run: > - python -m coverage run - --data-file=.coverage.server - launch.py - --skip-prepare-environment - --skip-torch-cuda-test - --test-server - --do-not-download-clip - --no-half - --disable-opt-split-attention - --always-cpu - --api-server-stop - --ckpt models/Stable-diffusion/realisticVisionV51_v51VAE.safetensors - 2>&1 | tee output.txt & - - name: Run tests - run: | - wait-for-it --service 127.0.0.1:7860 -t 20 - python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - # TODO(huchenlei): Enable ControlNet tests. Currently it is too slow to run these tests on CPU with - # real SD model. We need to find a way to load empty SD model. - # - name: Run ControlNet tests - # run: > - # python -m pytest - # --junitxml=test/results.xml - # --cov ./extensions-builtin/sd_forge_controlnet - # --cov-report=xml - # --verify-base-url - # ./extensions-builtin/sd_forge_controlnet/tests - - name: Kill test server - if: always() - run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 - - name: Upload main app output - uses: actions/upload-artifact@v3 - if: always() - with: - name: output - path: output.txt diff --git a/.github/workflows/warns_merge_master.yml b/.github/workflows/warns_merge_master.yml deleted file mode 100644 index ae2aab6ba..000000000 --- a/.github/workflows/warns_merge_master.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Pull requests can't target master branch - -"on": - pull_request: - types: - - opened - - synchronize - - reopened - branches: - - master - -jobs: - check: - runs-on: ubuntu-latest - steps: - - name: Warning marge into master - run: | - echo -e "::warning::This pull request directly merge into \"master\" branch, normally development happens on \"dev\" branch." 
- exit 1 diff --git a/.gitignore b/.gitignore index ca7c47ee1..b572a8f42 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ __pycache__ *.ckpt *.safetensors *.pth +*.dev.js +.DS_Store +/output/ /ESRGAN/* /SwinIR/* /repositories @@ -39,6 +42,9 @@ notification.mp3 /package-lock.json /.coverage* /test/test_outputs +/cache +trace.json +/sysinfo-????-??-??-??-??.json /test/results.xml coverage.xml -**/tests/**/expectations \ No newline at end of file +**/tests/**/expectations diff --git a/CHANGELOG.md b/CHANGELOG.md index 67429bbff..301bfd068 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,407 @@ +## 1.10.0 + +### Features: +* A lot of performance improvements (see below in Performance section) +* Stable Diffusion 3 support ([#16030](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16030), [#16164](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16164), [#16212](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16212)) + * Recommended Euler sampler; DDIM and other timestamp samplers currently not supported + * T5 text model is disabled by default, enable it in settings +* New schedulers: + * Align Your Steps ([#15751](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15751)) + * KL Optimal ([#15608](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608)) + * Normal ([#16149](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16149)) + * DDIM ([#16149](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16149)) + * Simple ([#16142](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16142)) + * Beta ([#16235](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16235)) +* New sampler: DDIM CFG++ ([#16035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16035)) + +### Minor: +* Option to skip CFG on early steps ([#15607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15607)) +* Add --models-dir option ([#15742](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15742)) +* Allow mobile users to open context menu by using two fingers press ([#15682](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15682)) +* Infotext: add Lora name as TI hashes for bundled Textual Inversion ([#15679](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15679)) +* Check model's hash after downloading it to prevent corrupted downloads ([#15602](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15602)) +* More extension tag filtering options ([#15627](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15627)) +* When saving AVIF, use JPEG's quality setting ([#15610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15610)) +* Add filename pattern: `[basename]` ([#15978](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15978)) +* Add option to enable clip skip for clip L on SDXL ([#15992](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15992)) +* Option to prevent screen sleep during generation ([#16001](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16001)) +* ToggleLivePreview button in image viewer ([#16065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16065)) +* Remove ui flashing on reloading and fast scrolling ([#16153](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16153)) +* option to disable save button log.csv ([#16242](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16242)) + +### Extensions and API: +* Add
process_before_every_sampling hook ([#15984](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15984)) +* Return HTTP 400 instead of 404 on invalid sampler error ([#16140](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16140)) + +### Performance: +* [Performance 1/6] use_checkpoint = False ([#15803](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15803)) +* [Performance 2/6] Replace einops.rearrange with torch native ops ([#15804](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15804)) +* [Performance 4/6] Precompute is_sdxl_inpaint flag ([#15806](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15806)) +* [Performance 5/6] Prevent unnecessary extra networks bias backup ([#15816](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15816)) +* [Performance 6/6] Add --precision half option to avoid casting during inference ([#15820](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15820)) +* [Performance] LDM optimization patches ([#15824](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15824)) +* [Performance] Keep sigmas on CPU ([#15823](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15823)) +* Check for nans in unet only once, after all steps have been completed +* Added option to run torch profiler for image generation + +### Bug Fixes: +* Fix for grids without comprehensive infotexts ([#15958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15958)) +* feat: lora partial update precede full update ([#15943](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15943)) +* Fix bug where file extension had an extra '.' under some circumstances ([#15893](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15893)) +* Fix corrupt model initial load loop ([#15600](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15600)) +* Allow old sampler names in API ([#15656](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15656)) +* more old sampler scheduler compatibility ([#15681](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15681)) +* Fix Hypertile xyz ([#15831](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15831)) +* XYZ CSV skipinitialspace ([#15832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15832)) +* fix soft inpainting on mps and xpu, torch_utils.float64 ([#15815](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15815)) +* fix extension update when not on main branch ([#15797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15797)) +* update pickle safe filenames +* use relative path for webui-assets css ([#15757](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15757)) +* When creating a virtual environment, upgrade pip in webui.bat/webui.sh ([#15750](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15750)) +* Fix AttributeError ([#15738](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15738)) +* use script_path for webui root in launch_utils ([#15705](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15705)) +* fix extra batch mode P Transparency ([#15664](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15664)) +* use gradio theme colors in css ([#15680](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15680)) +* Fix dragging text within prompt input ([#15657](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15657)) +* Add correct mimetype for .mjs files
([#15654](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15654)) +* QOL Items - handle metadata issues more cleanly for SD models, Loras and embeddings ([#15632](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15632)) +* replace wsl-open with wslpath and explorer.exe ([#15968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15968)) +* Fix SDXL Inpaint ([#15976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15976)) +* multi size grid ([#15988](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15988)) +* fix Replace preview ([#16118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16118)) +* Possible fix of wrong scale in weight decomposition ([#16151](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16151)) +* Ensure use of python from venv on Mac and Linux ([#16116](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16116)) +* Prioritize python3.10 over python3 if both are available on Linux and Mac (with fallback) ([#16092](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16092)) +* stopping generation extras ([#16085](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16085)) +* Fix SD2 loading ([#16078](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16078), [#16079](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16079)) +* fix infotext Lora hashes for hires fix different lora ([#16062](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16062)) +* Fix sampler scheduler autocorrection warning ([#16054](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16054)) +* fix ui flashing on reloading and fast scrolling ([#16153](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16153)) +* fix upscale logic ([#16239](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16239)) +* [bug] do not break progressbar on non-job actions (add wrap_gradio_call_no_job) ([#16202](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16202)) +* fix OSError: cannot write mode P as JPEG ([#16194](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16194)) + +### Other: +* fix changelog #15883 -> #15882 ([#15907](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15907)) +* ReloadUI backgroundColor --background-fill-primary ([#15864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15864)) +* Use different torch versions for Intel and ARM Macs ([#15851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15851)) +* XYZ override rework ([#15836](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15836)) +* scroll extensions table on overflow ([#15830](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15830)) +* img2img batch upload method ([#15817](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15817)) +* chore: sync v1.8.0 packages according to changelog ([#15783](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15783)) +* Add AVIF MIME type support to mimetype definitions ([#15739](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15739)) +* Update imageviewer.js ([#15730](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15730)) +* no-referrer ([#15641](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15641)) +* .gitignore trace.json ([#15980](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15980)) +* Bump spandrel to 0.3.4
([#16144](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16144)) +* Defunct --max-batch-count ([#16119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16119)) +* docs: update bug_report.yml ([#16102](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16102)) +* Maintaining Project Compatibility for Python 3.9 Users Without Upgrade Requirements. ([#16088](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16088), [#16169](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16169), [#16192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16192)) +* Update torch for ARM Macs to 2.3.1 ([#16059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16059)) +* remove deprecated setting dont_fix_second_order_samplers_schedule ([#16061](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16061)) +* chore: fix typos ([#16060](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16060)) +* shlex.join launch args in console log ([#16170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16170)) +* activate venv .bat ([#16231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16231)) +* add ids to the resize tabs in img2img ([#16218](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16218)) +* update installation guide linux ([#16178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16178)) +* Robust sysinfo ([#16173](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16173)) +* do not send image size on paste inpaint ([#16180](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16180)) +* Fix noisy DS_Store files for MacOS ([#16166](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16166)) + + +## 1.9.4 + +### Bug Fixes: +* pin setuptools version to fix the startup error ([#15882](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15882)) + +## 1.9.3 + +### Bug Fixes: +* fix get_crop_region_v2 ([#15594](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15594)) + +## 1.9.2 + +### Extensions and API: +* restore 1.8.0-style naming of scripts + +## 1.9.1 + +### Minor: +* Add avif support ([#15582](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15582)) +* Add filename patterns: `[sampler_scheduler]` and `[scheduler]` ([#15581](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15581)) + +### Extensions and API: +* undo adding scripts to sys.modules +* Add schedulers API endpoint ([#15577](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15577)) +* Remove API upscaling factor limits ([#15560](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15560)) + +### Bug Fixes: +* Fix images do not match / Coordinate 'right' is less than 'left' ([#15534](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15534)) +* fix: remove_callbacks_for_function should also remove from the ordered map ([#15533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15533)) +* fix x1 upscalers ([#15555](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15555)) +* Fix cls.__module__ value in extension script ([#15532](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15532)) +* fix typo in function call (eror -> error) ([#15531](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15531)) + +### Other: +* Hide 'No Image data blocks found.' 
message ([#15567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15567)) +* Allow webui.sh to be runnable from arbitrary directories containing a .git file ([#15561](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15561)) +* Compatibility with Debian 11, Fedora 34+ and openSUSE 15.4+ ([#15544](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15544)) +* numpy DeprecationWarning product -> prod ([#15547](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15547)) +* get_crop_region_v2 ([#15583](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15583), [#15587](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15587)) + + +## 1.9.0 + +### Features: +* Make refiner switchover based on model timesteps instead of sampling steps ([#14978](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14978)) +* add an option to have old-style directory view instead of tree view; stylistic changes for extra network sorting/search controls +* add UI for reordering callbacks, support for specifying callback order in extension metadata ([#15205](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15205)) +* Sgm uniform scheduler for SDXL-Lightning models ([#15325](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15325)) +* Scheduler selection in main UI ([#15333](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15333), [#15361](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15361), [#15394](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15394)) + +### Minor: +* "open images directory" button now opens the actual dir ([#14947](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14947)) +* Support inference with LyCORIS BOFT networks ([#14871](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14871), [#14973](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14973)) +* make extra network card description plaintext by default, with an option to re-enable HTML as it was +* resize handle for extra networks ([#15041](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15041)) +* cmd args: `--unix-filenames-sanitization` and `--filenames-max-length` ([#15031](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15031)) +* show extra networks parameters in HTML table rather than raw JSON ([#15131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15131)) +* Add DoRA (weight-decompose) support for LoRA/LoHa/LoKr ([#15160](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15160), [#15283](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15283)) +* Add '--no-prompt-history' cmd args for disable last generation prompt history ([#15189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15189)) +* update preview on Replace Preview ([#15201](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15201)) +* only fetch updates for extensions' active git branches ([#15233](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15233)) +* put upscale postprocessing UI into an accordion ([#15223](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15223)) +* Support dragdrop for URLs to read infotext ([#15262](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15262)) +* use diskcache library for caching ([#15287](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15287), [#15299](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15299)) +* Allow 
PNG-RGBA for Extras Tab ([#15334](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15334)) +* Support cover images embedded in safetensors metadata ([#15319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15319)) +* faster interrupt when using NN upscale ([#15380](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15380)) +* Extras upscaler: an input field to limit maximul side length for the output image ([#15293](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15293), [#15415](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15415), [#15417](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15417), [#15425](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15425)) +* add an option to hide postprocessing options in Extras tab + +### Extensions and API: +* ResizeHandleRow - allow overriden column scale parametr ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004)) +* call script_callbacks.ui_settings_callback earlier; fix extra-options-section built-in extension killing the ui if using a setting that doesn't exist +* make it possible to use zoom.js outside webui context ([#15286](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15286), [#15288](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15288)) +* allow variants for extension name in metadata.ini ([#15290](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15290)) +* make reloading UI scripts optional when doing Reload UI, and off by default +* put request: gr.Request at start of img2img function similar to txt2img +* open_folder as util ([#15442](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15442)) +* make it possible to import extensions' script files as `import scripts.` ([#15423](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15423)) + +### Performance: +* performance optimization for extra networks HTML pages +* optimization for extra networks filtering +* optimization for extra networks sorting + +### Bug Fixes: +* prevent escape button causing an interrupt when no generation has been made yet +* [bug] avoid doble upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966)) +* possible fix for reload button not appearing in some cases for extra networks. 
+* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006)) +* Fix resize-handle visability for vertical layout (mobile) ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010)) +* register_tmp_file also for mtime ([#15012](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15012)) +* Protect alphas_cumprod during refiner switchover ([#14979](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14979)) +* Fix EXIF orientation in API image loading ([#15062](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15062)) +* Only override emphasis if actually used in prompt ([#15141](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15141)) +* Fix emphasis infotext missing from `params.txt` ([#15142](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15142)) +* fix extract_style_text_from_prompt #15132 ([#15135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15135)) +* Fix Soft Inpaint for AnimateDiff ([#15148](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15148)) +* edit-attention: deselect surrounding whitespace ([#15178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15178)) +* chore: fix font not loaded ([#15183](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15183)) +* use natural sort in extra networks when ordering by path +* Fix built-in lora system bugs caused by torch.nn.MultiheadAttention ([#15190](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15190)) +* Avoid error from None in get_learned_conditioning ([#15191](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15191)) +* Add entry to MassFileLister after writing metadata ([#15199](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15199)) +* fix issue with Styles when Hires prompt is used ([#15269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15269), [#15276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15276)) +* Strip comments from hires fix prompt ([#15263](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15263)) +* Make imageviewer event listeners browser consistent ([#15261](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15261)) +* Fix AttributeError in OFT when trying to get MultiheadAttention weight ([#15260](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15260)) +* Add missing .mean() back ([#15239](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15239)) +* fix "Restore progress" button ([#15221](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15221)) +* fix ui-config for InputAccordion [custom_script_source] ([#15231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15231)) +* handle 0 wheel deltaY ([#15268](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15268)) +* prevent alt menu for firefox ([#15267](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15267)) +* fix: fix syntax errors ([#15179](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15179)) +* restore outputs path ([#15307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15307)) +* Escape btn_copy_path filename ([#15316](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15316)) +* Fix extra networks buttons when filename contains an apostrophe ([#15331](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15331)) +* 
escape brackets in lora random prompt generator ([#15343](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15343)) +* fix: Python version check for PyTorch installation compatibility ([#15390](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15390)) +* fix typo in call_queue.py ([#15386](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15386)) +* fix: when find already_loaded model, remove loaded by array index ([#15382](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15382)) +* minor bug fix of sd model memory management ([#15350](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15350)) +* Fix CodeFormer weight ([#15414](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15414)) +* Fix: Remove script callbacks in ordered_callbacks_map ([#15428](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15428)) +* fix limited file write (thanks, Sylwia) +* Fix extra-single-image API not doing upscale failed ([#15465](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15465)) +* error handling paste_field callables ([#15470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15470)) + +### Hardware: +* Add training support and change lspci for Ascend NPU ([#14981](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14981)) +* Update to ROCm5.7 and PyTorch ([#14820](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14820)) +* Better workaround for Navi1, removing --pre for Navi3 ([#15224](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15224)) +* Ascend NPU wiki page ([#15228](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15228)) + +### Other: +* Update comment for Pad prompt/negative prompt v0 to add a warning about truncation, make it override the v1 implementation +* support resizable columns for touch (tablets) ([#15002](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15002)) +* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995)) +* Use `absolute` path for normalized filepath ([#15035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15035)) +* resizeHandle handle double tap ([#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065)) +* --dat-models-path cmd flag ([#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039)) +* Add a direct link to the binary release ([#15059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15059)) +* upscaler_utils: Reduce logging ([#15084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15084)) +* Fix various typos with crate-ci/typos ([#15116](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15116)) +* fix_jpeg_live_preview ([#15102](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15102)) +* [alternative fix] can't load webui if selected wrong extra option in ui ([#15121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15121)) +* Error handling for unsupported transparency ([#14958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14958)) +* Add model description to searched terms ([#15198](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15198)) +* bump action version ([#15272](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15272)) +* PEP 604 annotations ([#15259](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15259)) +* Automatically Set the Scale by 
value when user selects an Upscale Model ([#15244](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15244)) +* move postprocessing-for-training into builtin extensions ([#15222](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15222)) +* type hinting in shared.py ([#15211](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15211)) +* update ruff to 0.3.3 +* Update pytorch lightning utilities ([#15310](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15310)) +* Add Size as an XYZ Grid option ([#15354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15354)) +* Use HF_ENDPOINT variable for HuggingFace domain with default ([#15443](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15443)) +* re-add update_file_entry ([#15446](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15446)) +* create_infotext allow index and callable, re-work Hires prompt infotext ([#15460](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15460)) +* update restricted_opts to include more options for --hide-ui-dir-config ([#15492](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15492)) + + +## 1.8.0 + +### Features: +* Update torch to version 2.1.2 +* Soft Inpainting ([#14208](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14208)) +* FP8 support ([#14031](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14031), [#14327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14327)) +* Support for SDXL-Inpaint Model ([#14390](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14390)) +* Use Spandrel for upscaling and face restoration architectures ([#14425](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14425), [#14467](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14467), [#14473](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14473), [#14474](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14474), [#14477](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14477), [#14476](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14476), [#14484](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14484), [#14500](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14500), [#14501](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14501), [#14504](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14504), [#14524](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14524), [#14809](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14809)) +* Automatic backwards version compatibility (when loading infotexts from old images with program version specified, will add compatibility settings) +* Implement zero terminal SNR noise schedule option (**[SEED BREAKING CHANGE](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Seed-breaking-changes#180-dev-170-225-2024-01-01---zero-terminal-snr-noise-schedule-option)**, [#14145](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14145), [#14979](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14979)) +* Add a [✨] button to run hires fix on selected image in the gallery (with help from [#14598](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14598), [#14626](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14626), [#14728](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14728)) +* [Separate assets
repository](https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets); serve fonts locally rather than from google's servers +* Official LCM Sampler Support ([#14583](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14583)) +* Add support for DAT upscaler models ([#14690](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14690), [#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039)) +* Extra Networks Tree View ([#14588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14588), [#14900](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14900)) +* NPU Support ([#14801](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14801)) +* Prompt comments support + +### Minor: +* Allow pasting in WIDTHxHEIGHT strings into the width/height fields ([#14296](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14296)) +* add option: Live preview in full page image viewer ([#14230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14230), [#14307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14307)) +* Add keyboard shortcuts for generate/skip/interrupt ([#14269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14269)) +* Better TCMALLOC support on different platforms ([#14227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14227), [#14883](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14883), [#14910](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14910)) +* Lora not found warning ([#14464](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14464)) +* Adding negative prompts to Loras in extra networks ([#14475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14475)) +* xyz_grid: allow varying the seed along an axis separate from axis options ([#12180](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12180)) +* option to convert VAE to bfloat16 (implementation of [#9295](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9295)) +* Better IPEX support ([#14229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14229), [#14353](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14353), [#14559](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14559), [#14562](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14562), [#14597](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14597)) +* Option to interrupt after current generation rather than immediately ([#13653](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13653), [#14659](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14659)) +* Fullscreen Preview control fading/disable ([#14291](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14291)) +* Finer settings freezing control ([#13789](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13789)) +* Increase Upscaler Limits ([#14589](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14589)) +* Adjust brush size with hotkeys ([#14638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14638)) +* Add checkpoint info to csv log file when saving images ([#14663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14663)) +* Make more columns resizable ([#14740](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14740), [#14884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14884)) +* Add an option to not overlay original image for inpainting for 
#14727 +* Add Pad conds v0 option to support same generation with DDIM as before 1.6.0 +* Add "Interrupting..." placeholder. +* Button for refresh extensions list ([#14857](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14857)) +* Add an option to disable normalization after calculating emphasis. ([#14874](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14874)) +* When counting tokens, also include enabled styles (can be disabled in settings to revert to previous behavior) +* Configuration for the [📂] button for image gallery ([#14947](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14947)) +* Support inference with LyCORIS BOFT networks ([#14871](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14871), [#14973](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14973)) +* support resizable columns for touch (tablets) ([#15002](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15002)) + +### Extensions and API: +* Removed packages from requirements: basicsr, gfpgan, realesrgan; as well as their dependencies: absl-py, addict, beautifulsoup4, future, gdown, grpcio, importlib-metadata, lmdb, lpips, Markdown, platformdirs, PySocks, soupsieve, tb-nightly, tensorboard-data-server, tomli, Werkzeug, yapf, zipp, soupsieve +* Enable task ids for API ([#14314](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14314)) +* add override_settings support for infotext API +* rename generation_parameters_copypaste module to infotext_utils +* prevent crash due to Script __init__ exception ([#14407](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14407)) +* Bump numpy to 1.26.2 ([#14471](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14471)) +* Add utility to inspect a model's dtype/device ([#14478](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14478)) +* Implement general forward method for all method in built-in lora ext ([#14547](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14547)) +* Execute model_loaded_callback after moving to target device ([#14563](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14563)) +* Add self to CFGDenoiserParams ([#14573](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14573)) +* Allow TLS with API only mode (--nowebui) ([#14593](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14593)) +* New callback: postprocess_image_after_composite ([#14657](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14657)) +* modules/api/api.py: add api endpoint to refresh embeddings list ([#14715](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14715)) +* set_named_arg ([#14773](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14773)) +* add before_token_counter callback and use it for prompt comments +* ResizeHandleRow - allow overridden column scale parameter ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004)) + +### Performance: +* Massive performance improvement for extra networks directories with a huge number of files in them in an attempt to tackle #14507 ([#14528](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14528)) +* Reduce unnecessary re-indexing extra networks directory ([#14512](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14512)) +* Avoid unnecessary `isfile`/`exists` calls ([#14527](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14527)) + +### Bug Fixes: +* fix multiple bugs related to styles 
multi-file support ([#14203](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14203), [#14276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14276), [#14707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14707)) +* Lora fixes ([#14300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14300), [#14237](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14237), [#14546](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14546), [#14726](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14726)) +* Re-add setting lost as part of e294e46 ([#14266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14266)) +* fix extras caption BLIP ([#14330](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14330)) +* include infotext into saved init image for img2img ([#14452](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14452)) +* xyz grid handle axis_type is None ([#14394](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14394)) +* Update Added (Fixed) IPV6 Functionality When there is No Webui Argument Passed webui.py ([#14354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14354)) +* fix API thread safe issues of txt2img and img2img ([#14421](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14421)) +* handle selectable script_index is None ([#14487](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14487)) +* handle config.json failed to load ([#14525](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14525), [#14767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14767)) +* paste infotext cast int as float ([#14523](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14523)) +* Ensure GRADIO_ANALYTICS_ENABLED is set early enough ([#14537](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14537)) +* Fix logging configuration again ([#14538](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14538)) +* Handle CondFunc exception when resolving attributes ([#14560](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14560)) +* Fix extras big batch crashes ([#14699](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14699)) +* Fix using wrong model caused by alias ([#14655](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14655)) +* Add # to the invalid_filename_chars list ([#14640](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14640)) +* Fix extension check for requirements ([#14639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14639)) +* Fix tab indexes are reset after restart UI ([#14637](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14637)) +* Fix nested manual cast ([#14689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14689)) +* Keep postprocessing upscale selected tab after restart ([#14702](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14702)) +* XYZ grid: filter out blank vals when axis is int or float type (like int axis seed) ([#14754](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14754)) +* fix CLIP Interrogator topN regex ([#14775](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14775)) +* Fix dtype error in MHA layer/change dtype checking mechanism for manual cast ([#14791](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14791)) +* catch load style.csv error 
([#14814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14814)) +* fix error when editing extra networks card +* fix extra networks metadata failing to work properly when you create the .json file with metadata for the first time. +* util.walk_files extensions case insensitive ([#14879](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14879)) +* if extensions page not loaded, prevent apply ([#14873](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14873)) +* call the right function for token counter in img2img +* Fix the bugs that search/reload will disappear when using other ExtraNetworks extensions ([#14939](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14939)) +* Gracefully handle mtime read exception from cache ([#14933](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14933)) +* Only trigger interrupt on `Esc` when interrupt button visible ([#14932](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14932)) +* Disable prompt token counters option actually disables token counting rather than just hiding results. +* avoid double upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966)) +* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995)) +* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006)) +* Fix resize-handle for mobile ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010), [#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065)) + +### Other: +* Assign id for "extra_options". Replace numeric field with slider. 
([#14270](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14270)) +* change state dict comparison to ref compare ([#14216](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14216)) +* Bump torch-rocm to 5.6/5.7 ([#14293](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14293)) +* Base output path off data path ([#14446](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14446)) +* reorder training preprocessing modules in extras tab ([#14367](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14367)) +* Remove `cleanup_models` code ([#14472](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14472)) +* only rewrite ui-config when there is change ([#14352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14352)) +* Fix lint issue from 501993eb ([#14495](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14495)) +* Update README.md ([#14548](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14548)) +* hires button, fix seeds () +* Logging: set formatter correctly for fallback logger too ([#14618](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14618)) +* Read generation info from infotexts rather than json for internal needs (save, extract seed from generated pic) ([#14645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14645)) +* improve get_crop_region ([#14709](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14709)) +* Bump safetensors' version to 0.4.2 ([#14782](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14782)) +* add tooltip create_submit_box ([#14803](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14803)) +* extensions tab table row hover highlight ([#14885](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14885)) +* Always add timestamp to displayed image ([#14890](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14890)) +* Added core.filemode=false so doesn't track changes in file permission… ([#14930](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14930)) +* Normalize command-line argument paths ([#14934](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14934), [#15035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15035)) +* Use original App Title in progress bar ([#14916](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14916)) +* register_tmp_file also for mtime ([#15012](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15012)) + ## 1.7.0 ### Features: @@ -40,7 +444,8 @@ * infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page * add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046)) * support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126)) -* allow use of mutiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125)) +* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125)) +* make extra network card description plaintext by default, with an option (Treat card description as HTML) to re-enable HTML as it was (originally by [#13241](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13241)) ### Extensions and API: * update gradio to 
3.41.2 @@ -176,7 +581,7 @@ * new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542)) * rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers: * makes all of them work with img2img - * makes prompt composition posssible (AND) + * makes prompt composition possible (AND) * makes them available for SDXL * always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808)) * use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599)) @@ -352,7 +757,7 @@ * user metadata system for custom networks * extended Lora metadata editor: set activation text, default weight, view tags, training info * Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension) - * show github stars for extenstions + * show github stars for extensions * img2img batch mode can read extra stuff from png info * img2img batch works with subdirectories * hotkeys to move prompt elements: alt+left/right @@ -571,7 +976,7 @@ * do not wait for Stable Diffusion model to load at startup * add filename patterns: `[denoising]` * directory hiding for extra networks: dirs starting with `.` will hide their cards on extra network tabs unless specifically searched for - * LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metdata of the file, if present, instead of filename (both can be used to activate LoRA) + * LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metadata of the file, if present, instead of filename (both can be used to activate LoRA) * LoRA: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active * LoRA: fix some LoRAs not working (ones that have 3x3 convolution layer) * LoRA: add an option to use old method of applying LoRAs (producing same results as with kohya-ss) @@ -601,7 +1006,7 @@ * fix gamepad navigation * make the lightbox fullscreen image function properly * fix squished thumbnails in extras tab - * keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed) + * keep "search" filter for extra networks when user refreshes the tab (previously it showed everything after you refreshed) * fix webui showing the same image if you configure the generation to always save results into same file * fix bug with upscalers not working properly * fix MPS on PyTorch 2.0.1, Intel Macs @@ -619,7 +1024,7 @@ * switch to PyTorch 2.0.0 (except for AMD GPUs) * visual improvements to custom code scripts * add filename patterns: `[clip_skip]`, `[hasprompt<>]`, `[batch_number]`, `[generation_number]` - * add support for saving init images in img2img, and record their hashes in infotext for reproducability + * add support for saving init images in img2img, and record their hashes in infotext for reproducibility * automatically select current word when adjusting weight with ctrl+up/down * add dropdowns for X/Y/Z plot * add setting: Stable Diffusion/Random number generator 
source: makes it possible to make images generated from a given manual seed consistent across different GPUs diff --git a/README.md b/README.md index 8eae83ceb..60ffac6f1 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,11 @@ -# Under Construction - -**Oops we are upgrading the repo now ... Please come back several hours later ...** - # Stable Diffusion WebUI Forge Stable Diffusion WebUI Forge is a platform on top of [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (based on [Gradio](https://www.gradio.app/)) to make development easier, optimize resource management, speed up inference, and study experimental features. The name "Forge" is inspired from "Minecraft Forge". This project is aimed at becoming SD WebUI's Forge. +This repo will undergo major change very recently. See also the [Announcement](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/801). + # Installing Forge If you are proficient in Git and you want to install Forge as another branch of SD-WebUI, please see [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/forge/master/docs/how-to-use.md#you-have-a1111-and-you-know-git). In this way, you can reuse all SD checkpoints and all extensions you installed previously in your OG SD-WebUI, but you should know what you are doing. @@ -24,6 +22,669 @@ Note that running `update.bat` is important, otherwise you may be using a previo ![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/c49bd60d-82bd-4086-9859-88d472582b94) -### Previous Versions +## Previous Versions You can download previous versions [here](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/849). + +# Screenshots of Comparison + +I tested with several devices, and this is a typical result from 8GB VRAM (3070ti laptop) with SDXL. + +**This is original WebUI:** + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/16893937-9ed9-4f8e-b960-70cd5d1e288f) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/7bbc16fe-64ef-49e2-a595-d91bb658bd94) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/de1747fd-47bc-482d-a5c6-0728dd475943) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/96e5e171-2d74-41ba-9dcc-11bf68be7e16) + +(average about 7.4GB/8GB, peak at about 7.9GB/8GB) + +**This is WebUI Forge:** + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/ca5e05ed-bd86-4ced-8662-f41034648e8c) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/3629ee36-4a99-4d9b-b371-12efb260a283) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/6d13ebb7-c30d-4aa8-9242-c0b5a1af8c95) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/c4f723c3-6ea7-4539-980b-0708ed2a69aa) + +(average and peak are all 6.3GB/8GB) + +You can see that Forge does not change WebUI results. Installing Forge is not a seed breaking change. + +Forge can perfectly keep WebUI unchanged even for most complicated prompts like `fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]`. + +All your previous works still work in Forge! + +# Forge Backend + +Forge backend removes all WebUI's codes related to resource management and reworked everything. 
All previous CMD flags like `medvram, lowvram, medvram-sdxl, precision full, no half, no half vae, attention_xxx, upcast unet`, ... are **REMOVED**. Adding these flags will not cause an error, but they no longer do anything. **We highly encourage Forge users to remove all cmd flags and let Forge decide how to load models.** + +Without any cmd flags, Forge can run SDXL with 4GB VRAM and SD1.5 with 2GB VRAM. + +**Some flags that you may still want to pay attention to:** + +1. `--always-offload-from-vram` (This flag will make things **slower** but less risky). This option makes Forge always unload models from VRAM. It can be useful if you run multiple pieces of software together and want Forge to use less VRAM and leave some VRAM for other software, when you are using old extensions that compete with Forge for VRAM, or (very rarely) when you get OOM. + +2. `--cuda-malloc` (This flag will make things **faster** but more risky). This asks pytorch to use *cudaMallocAsync* for tensor malloc. On some profilers I can observe a performance gain at the millisecond level, but the real speed-up on most of my devices is often unnoticeable (about 0.1 second per image or less). This cannot be set as default because many users have reported that async malloc crashes the program. Users need to enable this cmd flag at their own risk. + +3. `--cuda-stream` (This flag will make things **faster** but more risky). This uses pytorch CUDA streams (a special type of thread on the GPU) to move models and compute tensors simultaneously. This can almost eliminate all model-moving time and speed up SDXL on 30XX/40XX devices with small VRAM (eg, RTX 4050 6GB, RTX 3060 Laptop 6GB, etc) by about 15\% to 25\%. However, this unfortunately cannot be set as default because I observe a higher probability of pure black images (NaN outputs) on the 2060, and a higher chance of OOM on the 1080 and 2060. When the resolution is large, there is a chance that the computation time of one single attention layer is longer than the time needed to move the entire model to the GPU. When that happens, the next attention layer will OOM since the GPU is filled with the entire model and no space remains for computing another attention layer. Most overhead-detection methods are not robust enough to be reliable on old devices (in my tests). Users need to enable this cmd flag at their own risk. + +4. `--pin-shared-memory` (This flag will make things **faster** but more risky). Effective only when used together with `--cuda-stream`. This offloads modules to Shared GPU Memory instead of system RAM when offloading models. On some 30XX/40XX devices with small VRAM (eg, RTX 4050 6GB, RTX 3060 Laptop 6GB, etc), I can observe a significant (at least 20\%) speed-up for SDXL. However, this unfortunately cannot be set as default because OOM of Shared GPU Memory is a much more severe problem than common GPU memory OOM. Pytorch does not provide any robust method to unload or detect Shared GPU Memory. Once Shared GPU Memory OOMs, the entire program will crash (observed with SDXL on GTX 1060/1050/1066), and there is no dynamic way to prevent or recover from the crash. Users need to enable this cmd flag at their own risk.
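+
+For example (assuming Forge keeps upstream WebUI's `launch.py` entry point; the exact launcher is an assumption, and both flags remain at-your-own-risk as described above), enabling the two speed-oriented flags could look like:
+
+    python launch.py --cuda-stream --pin-shared-memory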
+ +If you really want to play with cmd flags, you can additionally control the GPU with: + +(extreme VRAM cases) + + --always-gpu + --always-cpu + +(rare attention cases) + + --attention-split + --attention-quad + --attention-pytorch + --disable-xformers + --disable-attention-upcast + +(float point type) + + --all-in-fp32 + --all-in-fp16 + --unet-in-bf16 + --unet-in-fp16 + --unet-in-fp8-e4m3fn + --unet-in-fp8-e5m2 + --vae-in-fp16 + --vae-in-fp32 + --vae-in-bf16 + --clip-in-fp8-e4m3fn + --clip-in-fp8-e5m2 + --clip-in-fp16 + --clip-in-fp32 + +(rare platforms) + + --directml + --disable-ipex-hijack + --pytorch-deterministic + +Again, Forge do not recommend users to use any cmd flags unless you are very sure that you really need these. + +# UNet Patcher + +Note that [Forge does not use any other software as backend](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/169). The full name of the backend is `Stable Diffusion WebUI with Forge backend`, or for simplicity, the `Forge backend`. The API and python symbols are made similar to previous software only for reducing the learning cost of developers. + +Now developing an extension is super simple. We finally have a patchable UNet. + +Below is using one single file with 80 lines of codes to support FreeU: + +`extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py` + +```python +import torch +import gradio as gr +from modules import scripts + + +def Fourier_filter(x, threshold, scale): + x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real + return x_filtered.to(x.dtype) + + +def set_freeu_v2_patch(model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + + def output_block_patch(h, hsp, *args, **kwargs): + scale = scale_dict.get(h.shape[1], None) + if scale is not None: + hidden_mean = h.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / \ + (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + h[:, :h.shape[1] // 2] = h[:, :h.shape[1] // 2] * ((scale[0] - 1) * hidden_mean + 1) + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return m + + +class FreeUForForge(scripts.Script): + def title(self): + return "FreeU Integrated" + + def show(self, is_img2img): + # make this extension visible in both txt2img and img2img tab. 
+ return scripts.AlwaysVisible + + def ui(self, *args, **kwargs): + with gr.Accordion(open=False, label=self.title()): + freeu_enabled = gr.Checkbox(label='Enabled', value=False) + freeu_b1 = gr.Slider(label='B1', minimum=0, maximum=2, step=0.01, value=1.01) + freeu_b2 = gr.Slider(label='B2', minimum=0, maximum=2, step=0.01, value=1.02) + freeu_s1 = gr.Slider(label='S1', minimum=0, maximum=4, step=0.01, value=0.99) + freeu_s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95) + + return freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 + + def process_before_every_sampling(self, p, *script_args, **kwargs): + # This will be called before every sampling. + # If you use highres fix, this will be called twice. + + freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args + + if not freeu_enabled: + return + + unet = p.sd_model.forge_objects.unet + + unet = set_freeu_v2_patch(unet, freeu_b1, freeu_b2, freeu_s1, freeu_s2) + + p.sd_model.forge_objects.unet = unet + + # Below codes will add some logs to the texts below the image outputs on UI. + # The extra_generation_params does not influence results. + p.extra_generation_params.update(dict( + freeu_enabled=freeu_enabled, + freeu_b1=freeu_b1, + freeu_b2=freeu_b2, + freeu_s1=freeu_s1, + freeu_s2=freeu_s2, + )) + + return +``` + +It looks like this: + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/277bac6e-5ea7-4bff-b71a-e55a60cfc03c) + +Similar components like HyperTile, KohyaHighResFix, SAG, can all be implemented within 100 lines of codes (see also the codes). + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/06472b03-b833-4816-ab47-70712ac024d3) + +ControlNets can finally be called by different extensions. + +Implementing Stable Video Diffusion and Zero123 are also super simple now (see also the codes). 
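+
+Before moving on to those larger examples, below is a minimal sketch of the shared pattern that the FreeU script above follows: clone the patchable UNet, register a patch, and hand the patched UNet back. The class name and the empty patch body are placeholders of mine, not part of Forge; treat this as an illustration distilled from the FreeU example rather than a ready-made extension.
+
+```python
+import gradio as gr
+from modules import scripts
+
+
+class MyPatchForForge(scripts.Script):
+    def title(self):
+        return "My Patch (sketch)"
+
+    def show(self, is_img2img):
+        # visible in both txt2img and img2img, same as the FreeU example
+        return scripts.AlwaysVisible
+
+    def ui(self, *args, **kwargs):
+        with gr.Accordion(open=False, label=self.title()):
+            enabled = gr.Checkbox(label='Enabled', value=False)
+        return [enabled]
+
+    def process_before_every_sampling(self, p, *script_args, **kwargs):
+        enabled, = script_args
+        if not enabled:
+            return
+
+        # 1. clone the patchable UNet so other extensions are not affected
+        unet = p.sd_model.forge_objects.unet.clone()
+
+        # 2. register a patch; a real extension would modify h / hsp here
+        def output_block_patch(h, hsp, *args, **kwargs):
+            return h, hsp
+
+        unet.set_model_output_block_patch(output_block_patch)
+
+        # 3. hand the patched UNet back to the processing object
+        p.sd_model.forge_objects.unet = unet
+```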
+ +*Stable Video Diffusion:* + +`extensions-builtin/sd_forge_svd/scripts/forge_svd.py` + +```python +import torch +import gradio as gr +import os +import pathlib + +from modules import script_callbacks +from modules.paths import models_path +from modules.ui_common import ToolButton, refresh_symbol +from modules import shared + +from modules_forge.forge_util import numpy_to_pytorch, pytorch_to_numpy +from ldm_patched.modules.sd import load_checkpoint_guess_config +from ldm_patched.contrib.external_video_model import VideoLinearCFGGuidance, SVD_img2vid_Conditioning +from ldm_patched.contrib.external import KSampler, VAEDecode + + +opVideoLinearCFGGuidance = VideoLinearCFGGuidance() +opSVD_img2vid_Conditioning = SVD_img2vid_Conditioning() +opKSampler = KSampler() +opVAEDecode = VAEDecode() + +svd_root = os.path.join(models_path, 'svd') +os.makedirs(svd_root, exist_ok=True) +svd_filenames = [] + + +def update_svd_filenames(): + global svd_filenames + svd_filenames = [ + pathlib.Path(x).name for x in + shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"]) + ] + return svd_filenames + + +@torch.inference_mode() +@torch.no_grad() +def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level, + sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, + sampling_denoise, guidance_min_cfg, input_image): + filename = os.path.join(svd_root, filename) + model_raw, _, vae, clip_vision = \ + load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True) + model = opVideoLinearCFGGuidance.patch(model_raw, guidance_min_cfg)[0] + init_image = numpy_to_pytorch(input_image) + positive, negative, latent_image = opSVD_img2vid_Conditioning.encode( + clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) + output_latent = opKSampler.sample(model, sampling_seed, sampling_steps, sampling_cfg, + sampling_sampler_name, sampling_scheduler, positive, + negative, latent_image, sampling_denoise)[0] + output_pixels = opVAEDecode.decode(vae, output_latent)[0] + outputs = pytorch_to_numpy(output_pixels) + return outputs + + +def on_ui_tabs(): + with gr.Blocks() as svd_block: + with gr.Row(): + with gr.Column(): + input_image = gr.Image(label='Input Image', source='upload', type='numpy', height=400) + + with gr.Row(): + filename = gr.Dropdown(label="SVD Checkpoint Filename", + choices=svd_filenames, + value=svd_filenames[0] if len(svd_filenames) > 0 else None) + refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh") + refresh_button.click( + fn=lambda: gr.update(choices=update_svd_filenames), + inputs=[], outputs=filename) + + width = gr.Slider(label='Width', minimum=16, maximum=8192, step=8, value=1024) + height = gr.Slider(label='Height', minimum=16, maximum=8192, step=8, value=576) + video_frames = gr.Slider(label='Video Frames', minimum=1, maximum=4096, step=1, value=14) + motion_bucket_id = gr.Slider(label='Motion Bucket Id', minimum=1, maximum=1023, step=1, value=127) + fps = gr.Slider(label='Fps', minimum=1, maximum=1024, step=1, value=6) + augmentation_level = gr.Slider(label='Augmentation Level', minimum=0.0, maximum=10.0, step=0.01, + value=0.0) + sampling_steps = gr.Slider(label='Sampling Steps', minimum=1, maximum=200, step=1, value=20) + sampling_cfg = gr.Slider(label='CFG Scale', minimum=0.0, maximum=50.0, step=0.1, value=2.5) + sampling_denoise = gr.Slider(label='Sampling Denoise', minimum=0.0, maximum=1.0, step=0.01, 
value=1.0) + guidance_min_cfg = gr.Slider(label='Guidance Min Cfg', minimum=0.0, maximum=100.0, step=0.5, value=1.0) + sampling_sampler_name = gr.Radio(label='Sampler Name', + choices=['euler', 'euler_ancestral', 'heun', 'heunpp2', 'dpm_2', + 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive', + 'dpmpp_2s_ancestral', 'dpmpp_sde', 'dpmpp_sde_gpu', + 'dpmpp_2m', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', + 'dpmpp_3m_sde', 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ddim', + 'uni_pc', 'uni_pc_bh2'], value='euler') + sampling_scheduler = gr.Radio(label='Scheduler', + choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', + 'ddim_uniform'], value='karras') + sampling_seed = gr.Number(label='Seed', value=12345, precision=0) + + generate_button = gr.Button(value="Generate") + + ctrls = [filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level, + sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, + sampling_denoise, guidance_min_cfg, input_image] + + with gr.Column(): + output_gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain', + visible=True, height=1024, columns=4) + + generate_button.click(predict, inputs=ctrls, outputs=[output_gallery]) + return [(svd_block, "SVD", "svd")] + + +update_svd_filenames() +script_callbacks.on_ui_tabs(on_ui_tabs) +``` + +Note that although the above codes look like independent codes, they actually will automatically offload/unload any other models. For example, below is me opening webui, load SDXL, generated an image, then go to SVD, then generated image frames. You can see that the GPU memory is perfectly managed and the SDXL is moved to RAM then SVD is moved to GPU. + +Note that this management is fully automatic. This makes writing extensions super simple. + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/de1a2d05-344a-44d7-bab8-9ecc0a58a8d3) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/14bcefcf-599f-42c3-bce9-3fd5e428dd91) + +Similarly, Zero123: + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/7685019c-7239-47fb-9cb5-2b7b33943285) + +### Write a simple ControlNet: + +Below is a simple extension to have a completely independent pass of ControlNet that never conflicts any other extensions: + +`extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py` + +Note that this extension is hidden because it is only for developers. To see it in UI, use `--show-controlnet-example`. + +The memory optimization in this example is fully automatic. You do not need to care about memory and inference speed, but you may want to cache objects if you wish. + +```python +# Use --show-controlnet-example to see this extension. + +import cv2 +import gradio as gr +import torch + +from modules import scripts +from modules.shared_cmd_options import cmd_opts +from modules_forge.shared import supported_preprocessors +from modules.modelloader import load_file_from_url +from ldm_patched.modules.controlnet import load_controlnet +from modules_forge.controlnet import apply_controlnet_advanced +from modules_forge.forge_util import numpy_to_pytorch +from modules_forge.shared import controlnet_dir + + +class ControlNetExampleForge(scripts.Script): + model = None + + def title(self): + return "ControlNet Example for Developers" + + def show(self, is_img2img): + # make this extension visible in both txt2img and img2img tab. 
+ return scripts.AlwaysVisible + + def ui(self, *args, **kwargs): + with gr.Accordion(open=False, label=self.title()): + gr.HTML('This is an example controlnet extension for developers.') + gr.HTML('You see this extension because you used --show-controlnet-example') + input_image = gr.Image(source='upload', type='numpy') + funny_slider = gr.Slider(label='This slider does nothing. It just shows you how to transfer parameters.', + minimum=0.0, maximum=1.0, value=0.5) + + return input_image, funny_slider + + def process(self, p, *script_args, **kwargs): + input_image, funny_slider = script_args + + # This slider does nothing. It just shows you how to transfer parameters. + del funny_slider + + if input_image is None: + return + + # controlnet_canny_path = load_file_from_url( + # url='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_canny_256lora.safetensors', + # model_dir=model_dir, + # file_name='sai_xl_canny_256lora.safetensors' + # ) + controlnet_canny_path = load_file_from_url( + url='https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/control_v11p_sd15_canny_fp16.safetensors', + model_dir=controlnet_dir, + file_name='control_v11p_sd15_canny_fp16.safetensors' + ) + print('The model [control_v11p_sd15_canny_fp16.safetensors] download finished.') + + self.model = load_controlnet(controlnet_canny_path) + print('Controlnet loaded.') + + return + + def process_before_every_sampling(self, p, *script_args, **kwargs): + # This will be called before every sampling. + # If you use highres fix, this will be called twice. + + input_image, funny_slider = script_args + + if input_image is None or self.model is None: + return + + B, C, H, W = kwargs['noise'].shape # latent_shape + height = H * 8 + width = W * 8 + batch_size = p.batch_size + + preprocessor = supported_preprocessors['canny'] + + # detect control at certain resolution + control_image = preprocessor( + input_image, resolution=512, slider_1=100, slider_2=200, slider_3=None) + + # here we just use nearest neighbour to align input shape. + # You may want crop and resize, or crop and fill, or others. + control_image = cv2.resize( + control_image, (width, height), interpolation=cv2.INTER_NEAREST) + + # Output preprocessor result. Now called every sampling. Cache in your own way. + p.extra_result_images.append(control_image) + + print('Preprocessor Canny finished.') + + control_image_bchw = numpy_to_pytorch(control_image).movedim(-1, 1) + + unet = p.sd_model.forge_objects.unet + + # Unet has input, middle, output blocks, and we can give different weights + # to each layers in all blocks. + # Below is an example for stronger control in middle block. + # This is helpful for some high-res fix passes. (p.is_hr_pass) + positive_advanced_weighting = { + 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], + 'middle': [1.0], + 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] + } + negative_advanced_weighting = { + 'input': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25], + 'middle': [1.05], + 'output': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25] + } + + # The advanced_frame_weighting is a weight applied to each image in a batch. + # The length of this list must be same with batch size + # For example, if batch size is 5, the below list is [0.2, 0.4, 0.6, 0.8, 1.0] + # If you view the 5 images as 5 frames in a video, this will lead to + # progressively stronger control over time. 
+ advanced_frame_weighting = [float(i + 1) / float(batch_size) for i in range(batch_size)] + + # The advanced_sigma_weighting allows you to dynamically compute control + # weights given diffusion timestep (sigma). + # For example below code can softly make beginning steps stronger than ending steps. + sigma_max = unet.model.model_sampling.sigma_max + sigma_min = unet.model.model_sampling.sigma_min + advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) + + # You can even input a tensor to mask all control injections + # The mask will be automatically resized during inference in UNet. + # The size should be B 1 H W and the H and W are not important + # because they will be resized automatically + advanced_mask_weighting = torch.ones(size=(1, 1, 512, 512)) + + # But in this simple example we do not use them + positive_advanced_weighting = None + negative_advanced_weighting = None + advanced_frame_weighting = None + advanced_sigma_weighting = None + advanced_mask_weighting = None + + unet = apply_controlnet_advanced(unet=unet, controlnet=self.model, image_bchw=control_image_bchw, + strength=0.6, start_percent=0.0, end_percent=0.8, + positive_advanced_weighting=positive_advanced_weighting, + negative_advanced_weighting=negative_advanced_weighting, + advanced_frame_weighting=advanced_frame_weighting, + advanced_sigma_weighting=advanced_sigma_weighting, + advanced_mask_weighting=advanced_mask_weighting) + + p.sd_model.forge_objects.unet = unet + + # Below codes will add some logs to the texts below the image outputs on UI. + # The extra_generation_params does not influence results. + p.extra_generation_params.update(dict( + controlnet_info='You should see these texts below output images!', + )) + + return + + +# Use --show-controlnet-example to see this extension. +if not cmd_opts.show_controlnet_example: + del ControlNetExampleForge + +``` + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/822fa2fc-c9f4-4f58-8669-4b6680b91063) + + +### Add a preprocessor + +Below is the full codes to add a normalbae preprocessor with perfect memory managements. + +You can use arbitrary independent extensions to add a preprocessor. 
+ +Your preprocessor will be read by all other extensions using `modules_forge.shared.preprocessors` + +Below codes are in `extensions-builtin\forge_preprocessor_normalbae\scripts\preprocessor_normalbae.py` + +```python +from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter +from modules_forge.shared import preprocessor_dir, add_supported_preprocessor +from modules_forge.forge_util import resize_image_with_pad +from modules.modelloader import load_file_from_url + +import types +import torch +import numpy as np + +from einops import rearrange +from annotator.normalbae.models.NNET import NNET +from annotator.normalbae import load_checkpoint +from torchvision import transforms + + +class PreprocessorNormalBae(Preprocessor): + def __init__(self): + super().__init__() + self.name = 'normalbae' + self.tags = ['NormalMap'] + self.model_filename_filters = ['normal'] + self.slider_resolution = PreprocessorParameter( + label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) + self.slider_1 = PreprocessorParameter(visible=False) + self.slider_2 = PreprocessorParameter(visible=False) + self.slider_3 = PreprocessorParameter(visible=False) + self.show_control_mode = True + self.do_not_need_model = False + self.sorting_priority = 100 # higher goes to top in the list + + def load_model(self): + if self.model_patcher is not None: + return + + model_path = load_file_from_url( + "https://huggingface.co/lllyasviel/Annotators/resolve/main/scannet.pt", + model_dir=preprocessor_dir) + + args = types.SimpleNamespace() + args.mode = 'client' + args.architecture = 'BN' + args.pretrained = 'scannet' + args.sampling_ratio = 0.4 + args.importance_ratio = 0.7 + model = NNET(args) + model = load_checkpoint(model_path, model) + self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + self.model_patcher = self.setup_model_patcher(model) + + def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs): + input_image, remove_pad = resize_image_with_pad(input_image, resolution) + + self.load_model() + + self.move_all_model_patchers_to_gpu() + + assert input_image.ndim == 3 + image_normal = input_image + + with torch.no_grad(): + image_normal = self.send_tensor_to_model_device(torch.from_numpy(image_normal)) + image_normal = image_normal / 255.0 + image_normal = rearrange(image_normal, 'h w c -> 1 c h w') + image_normal = self.norm(image_normal) + + normal = self.model_patcher.model(image_normal) + normal = normal[0][-1][:, :3] + normal = ((normal + 1) * 0.5).clip(0, 1) + + normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy() + normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8) + + return remove_pad(normal_image) + + +add_supported_preprocessor(PreprocessorNormalBae()) + +``` + +# New features (that are not available in original WebUI) + +Thanks to Unet Patcher, many new things are possible now and supported in Forge, including SVD, Z123, masked Ip-adapter, masked controlnet, photomaker, etc. 
+ +Masked Ip-Adapter + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/d26630f9-922d-4483-8bf9-f364dca5fd50) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/03580ef7-235c-4b03-9ca6-a27677a5a175) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/d9ed4a01-70d4-45b4-a6a7-2f765f158fae) + +Masked ControlNet + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/872d4785-60e4-4431-85c7-665c781dddaa) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/335a3b33-1ef8-46ff-a462-9f1b4f2c49fc) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/b3684a15-8895-414e-8188-487269dfcada) + +PhotoMaker + +(Note that PhotoMaker is a special control that needs you to add the trigger word "photomaker". Your prompt should be something like "a photo of photomaker".) + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/07b0b626-05b5-473b-9d69-3657624d59be) + +Marigold Depth + +![image](https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/bdf54148-892d-410d-8ed9-70b4b121b6e7) + +# New Samplers (that are not in the original WebUI) + + DDPM + DDPM Karras + DPM++ 2M Turbo + DPM++ 2M SDE Turbo + LCM Karras + Euler A Turbo + +# About Extensions + +ControlNet and TiledVAE are integrated, and you should uninstall these two extensions: + + sd-webui-controlnet + multidiffusion-upscaler-for-automatic1111 + +Note that **AnimateDiff** is under construction by [continue-revolution](https://github.com/continue-revolution) at [sd-webui-animatediff forge/master branch](https://github.com/continue-revolution/sd-webui-animatediff/tree/forge/master) and [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff) (they are in sync). (continue-revolution's original words: "prompt travel, inf t2v, controlnet v2v have been proven to work well; motion lora, i2i batch still under construction and may be finished in a week") + +Other extensions should work without problems, like: + + canvas-zoom + translations/localizations + Dynamic Prompts + Adetailer + Ultimate SD Upscale + Reactor + +However, if newer extensions use Forge, their code can be much shorter. + +Usually, if an old extension is reworked using Forge's unet patcher, 80% of its code can be removed, especially when it needs to call a controlnet. + +# Contribution + +Forge uses a bot to pull commits and code from https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/dev every afternoon (if the merge succeeds automatically via a git bot, my compiler, or my ChatGPT bot) or at midnight (if my compiler and my ChatGPT bot both fail to merge and I review it manually). + +All PRs that can be implemented in https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/dev should be submitted there. + +Feel free to submit PRs related to the functionality of Forge here.
diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 000000000..1c63fe703 --- /dev/null +++ b/_typos.toml @@ -0,0 +1,5 @@ +[default.extend-words] +# Part of "RGBa" (Pillow's pre-multiplied alpha RGB mode) +Ba = "Ba" +# HSA is something AMD uses for their GPUs +HSA = "HSA" diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml index cfbee72d7..4944ab5c8 100644 --- a/configs/alt-diffusion-inference.yaml +++ b/configs/alt-diffusion-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml index 41a031d55..c60dca8c7 100644 --- a/configs/alt-diffusion-m18-inference.yaml +++ b/configs/alt-diffusion-m18-inference.yaml @@ -41,7 +41,7 @@ model: use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml index 4e896879d..564e50ae2 100644 --- a/configs/instruct-pix2pix.yaml +++ b/configs/instruct-pix2pix.yaml @@ -45,7 +45,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/sd3-inference.yaml b/configs/sd3-inference.yaml new file mode 100644 index 000000000..bccb69d2e --- /dev/null +++ b/configs/sd3-inference.yaml @@ -0,0 +1,5 @@ +model: + target: modules.models.sd3.sd3_model.SD3Inferencer + params: + shift: 3 + state_dict: null diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml index 3bad37218..f40f45e33 100644 --- a/configs/sd_xl_inpaint.yaml +++ b/configs/sd_xl_inpaint.yaml @@ -21,7 +21,7 @@ model: params: adm_in_channels: 2816 num_classes: sequential - use_checkpoint: True + use_checkpoint: False in_channels: 9 out_channels: 4 model_channels: 320 diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml index d4effe569..25c4d9ed0 100644 --- a/configs/v1-inference.yaml +++ b/configs/v1-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml index f9eec37d2..68c199f99 100644 --- a/configs/v1-inpainting-inference.yaml +++ b/configs/v1-inpainting-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 04adc5eb2..51ab18212 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -301,7 +301,7 @@ def p_losses(self, x_start, t, noise=None): elif self.parameterization == "x0": target = x_start else: - raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) @@ -572,7 +572,7 @@ def delta_border(self, h, w): :param h: height :param w: width :return: normalized distance to image border, - wtith min distance = 0 at 
border and max dist = 0.5 at image center + with min distance = 0 at border and max dist = 0.5 at image center """ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) arr = self.meshgrid(h, w) / lower_right_corner @@ -880,7 +880,7 @@ def forward(self, x, c, *args, **kwargs): def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict + # hybrid case, cond is expected to be a dict pass else: if not isinstance(cond, list): @@ -916,7 +916,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] elif self.cond_stage_key == 'coordinates_bbox': - assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size' # assuming padding of unfold is always 0 and its dilation is always 1 n_patches_per_row = int((w - ks[0]) / stride[0] + 1) @@ -926,7 +926,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): num_downs = self.first_stage_model.encoder.num_resolutions - 1 rescale_latent = 2 ** (num_downs) - # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # get top left positions of patches as conforming for the bbbox tokenizer, therefore we # need to rescale the tl patch coordinates to be in between (0,1) tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 005ff32cb..17a620f77 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -9,6 +9,8 @@ def __init__(self): self.errors = {} """mapping of network names to the number of errors the network had during operation""" + remove_symbols = str.maketrans('', '', ":,") + def activate(self, p, params_list): additional = shared.opts.sd_lora @@ -43,22 +45,15 @@ def activate(self, p, params_list): networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) if shared.opts.lora_add_hashes_to_infotext: - network_hashes = [] - for item in networks.loaded_networks: - shorthash = item.network_on_disk.shorthash - if not shorthash: - continue - - alias = item.mentioned_name - if not alias: - continue + if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"): + p.lora_hashes = {} - alias = alias.replace(":", "").replace(",", "") - - network_hashes.append(f"{alias}: {shorthash}") + for item in networks.loaded_networks: + if item.network_on_disk.shorthash and item.mentioned_name: + p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash - if network_hashes: - p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) + if p.lora_hashes: + p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items()) def deactivate(self, p): if self.errors: diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 5eb7de96b..89987438a 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from modules import sd_models, cache, errors, hashes, shared +import modules.models.sd3.mmdit NetworkWeights = 
namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) @@ -29,7 +30,6 @@ def __init__(self, name, filename): def read_metadata(): metadata = sd_models.read_metadata_from_safetensors(filename) - metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text return metadata @@ -115,8 +115,17 @@ def __init__(self, net: Network, weights: NetworkWeights): self.sd_key = weights.sd_key self.sd_module = weights.sd_module - if hasattr(self.sd_module, 'weight'): + if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear): + s = self.sd_module.weight.shape + self.shape = (s[0] // 3, s[1]) + elif hasattr(self.sd_module, 'weight'): self.shape = self.sd_module.weight.shape + elif isinstance(self.sd_module, nn.MultiheadAttention): + # For now, only self-attn use Pytorch's MHA + # So assume all qkvo proj have same shape + self.shape = self.sd_module.out_proj.weight.shape + else: + self.shape = None self.ops = None self.extra_kwargs = {} @@ -146,6 +155,9 @@ def __init__(self, net: Network, weights: NetworkWeights): self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None self.scale = weights.w["scale"].item() if "scale" in weights.w else None + self.dora_scale = weights.w.get("dora_scale", None) + self.dora_norm_dims = len(self.shape) - 1 + def multiplier(self): if 'transformer' in self.sd_key[:20]: return self.network.te_multiplier @@ -160,6 +172,27 @@ def calc_scale(self): return 1.0 + def apply_weight_decompose(self, updown, orig_weight): + # Match the device/dtype + orig_weight = orig_weight.to(updown.dtype) + dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) + updown = updown.to(orig_weight.device) + + merged_scale1 = updown + orig_weight + merged_scale1_norm = ( + merged_scale1.transpose(0, 1) + .reshape(merged_scale1.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + + dora_merged = ( + merged_scale1 * (dora_scale / merged_scale1_norm) + ) + final_updown = dora_merged - orig_weight + return final_updown + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -175,7 +208,12 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown * self.calc_scale() * self.multiplier(), ex_bias + updown = updown * self.calc_scale() + + if self.dora_scale is not None: + updown = self.apply_weight_decompose(updown, orig_weight) + + return updown * self.multiplier(), ex_bias def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7a5a5269c..9ebeda6ec 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,6 @@ +from __future__ import annotations +import gradio as gr +import logging import os import re @@ -26,6 +29,14 @@ def assign_network_names_to_compvis_modules(sd_model): pass +class BundledTIHash(str): + def __init__(self, hash_str): + self.hash = hash_str + + def __str__(self): + return self.hash if shared.opts.lora_bundled_ti_to_infotext else '' + + def load_network(name, network_on_disk): net = network.Network(name, network_on_disk) net.mtime = os.path.getmtime(network_on_disk.filename) @@ -46,6 +57,16 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No 
loaded_networks.clear() + unavailable_networks = [] + for name in names: + if name.lower() in forbidden_network_aliases and available_networks.get(name) is None: + unavailable_networks.append(name) + elif available_network_aliases.get(name) is None: + unavailable_networks.append(name) + + if unavailable_networks: + update_available_networks_by_names(unavailable_networks) + networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names] if any(x is None for x in networks_on_disk): list_available_networks() @@ -84,6 +105,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No return +def allowed_layer_without_weight(layer): + if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine: + return True + + return False + + +def store_weights_backup(weight): + if weight is None: + return None + + return weight.to(devices.cpu, copy=True) + + +def restore_weights_backup(obj, field, weight): + if weight is None: + setattr(obj, field, None) + return + + getattr(obj, field).copy_(weight) + + def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): pass @@ -140,21 +183,15 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): pass -def list_available_networks(): - available_networks.clear() - available_network_aliases.clear() - forbidden_network_aliases.clear() - available_network_hash_lookup.clear() - forbidden_network_aliases.update({"none": 1, "Addams": 1}) - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - +def process_network_files(names: list[str] | None = None): candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) for filename in candidates: if os.path.isdir(filename): continue - name = os.path.splitext(os.path.basename(filename))[0] + # if names is provided, only load networks with names in the list + if names and name not in names: + continue try: entry = network.NetworkOnDisk(name, filename) except OSError: # should catch FileNotFoundError and PermissionError etc. 
@@ -170,6 +207,22 @@ def list_available_networks(): available_network_aliases[entry.alias] = entry +def update_available_networks_by_names(names: list[str]): + process_network_files(names) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + process_network_files() + + re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index daf870625..9e9e4ad8d 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -35,7 +35,8 @@ def before_ui(): "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), - "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), + "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), + "lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 3160aecfa..b6c4d1c6a 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -21,10 +21,12 @@ def is_non_comma_tagset(tags): def build_tags(metadata): tags = {} - for _, tags_dict in metadata.get("ss_tag_frequency", {}).items(): - for tag, tag_count in tags_dict.items(): - tag = tag.strip() - tags[tag] = tags.get(tag, 0) + int(tag_count) + ss_tag_frequency = metadata.get("ss_tag_frequency", {}) + if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'): + for _, tags_dict in ss_tag_frequency.items(): + for tag, tag_count in tags_dict.items(): + tag = tag.strip() + tags[tag] = tags.get(tag, 0) + int(tag_count) if tags and is_non_comma_tagset(tags): new_tags = {} @@ -149,6 +151,8 @@ def generate_random_prompt_from_tags(self, tags): v = random.random() * max_count if count > v: + for x in "({[]})": + tag = tag.replace(x, '\\' + x) res.append(tag) return ", ".join(sorted(res)) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 28f82ea4b..35e71be3b 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ 
b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -31,7 +31,7 @@ def create_item(self, name, index=None, enable_filter=True): "name": name, "filename": lora_on_disk.filename, "shorthash": lora_on_disk.shorthash, - "preview": self.find_preview(path), + "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata), "description": self.find_description(path), "search_terms": search_terms, "local_preview": f"{path}.{shared.opts.samples_format}", @@ -60,7 +60,7 @@ def create_item(self, name, index=None, enable_filter=True): else: sd_version = lora_on_disk.sd_version - if shared.opts.lora_show_all or not enable_filter: + if shared.opts.lora_filter_disabled or not enable_filter or not shared.sd_model: pass elif sd_version == network.SdVersion.Unknown: model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js deleted file mode 100644 index df60c1a17..000000000 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ /dev/null @@ -1,968 +0,0 @@ -onUiLoaded(async() => { - const elementIDs = { - img2imgTabs: "#mode_img2img .tab-nav", - inpaint: "#img2maskimg", - inpaintSketch: "#inpaint_sketch", - rangeGroup: "#img2img_column_size", - sketch: "#img2img_sketch" - }; - const tabNameToElementId = { - "Inpaint sketch": elementIDs.inpaintSketch, - "Inpaint": elementIDs.inpaint, - "Sketch": elementIDs.sketch - }; - - - // Helper functions - // Get active tab - - /** - * Waits for an element to be present in the DOM. - */ - const waitForElement = (id) => new Promise(resolve => { - const checkForElement = () => { - const element = document.querySelector(id); - if (element) return resolve(element); - setTimeout(checkForElement, 100); - }; - checkForElement(); - }); - - function getActiveTab(elements, all = false) { - const tabs = elements.img2imgTabs.querySelectorAll("button"); - - if (all) return tabs; - - for (let tab of tabs) { - if (tab.classList.contains("selected")) { - return tab; - } - } - } - - // Get tab ID - function getTabId(elements) { - const activeTab = getActiveTab(elements); - return tabNameToElementId[activeTab.innerText]; - } - - // Wait until opts loaded - async function waitForOpts() { - for (; ;) { - if (window.opts && Object.keys(window.opts).length) { - return window.opts; - } - await new Promise(resolve => setTimeout(resolve, 100)); - } - } - - // Detect whether the element has a horizontal scroll bar - function hasHorizontalScrollbar(element) { - return element.scrollWidth > element.clientWidth; - } - - // Function for defining the "Ctrl", "Shift" and "Alt" keys - function isModifierKey(event, key) { - switch (key) { - case "Ctrl": - return event.ctrlKey; - case "Shift": - return event.shiftKey; - case "Alt": - return event.altKey; - default: - return false; - } - } - - // Check if hotkey is valid - function isValidHotkey(value) { - const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"]; - return ( - (typeof value === "string" && - value.length === 1 && - /[a-z]/i.test(value)) || - specialKeys.includes(value) - ); - } - - // Normalize hotkey - function normalizeHotkey(hotkey) { - return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey; - } - - // Format hotkey for display - function formatHotkeyForDisplay(hotkey) { - return hotkey.startsWith("Key") ? 
hotkey.slice(3) : hotkey; - } - - // Create hotkey configuration with the provided options - function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) { - const result = {}; // Resulting hotkey configuration - const usedKeys = new Set(); // Set of used hotkeys - - // Iterate through defaultHotkeysConfig keys - for (const key in defaultHotkeysConfig) { - const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value - const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value - - // Apply appropriate value for undefined, boolean, or object userValue - if ( - userValue === undefined || - typeof userValue === "boolean" || - typeof userValue === "object" || - userValue === "disable" - ) { - result[key] = - userValue === undefined ? defaultValue : userValue; - } else if (isValidHotkey(userValue)) { - const normalizedUserValue = normalizeHotkey(userValue); - - // Check for conflicting hotkeys - if (!usedKeys.has(normalizedUserValue)) { - usedKeys.add(normalizedUserValue); - result[key] = normalizedUserValue; - } else { - console.error( - `Hotkey: ${formatHotkeyForDisplay( - userValue - )} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay( - defaultValue - )}` - ); - result[key] = defaultValue; - } - } else { - console.error( - `Hotkey: ${formatHotkeyForDisplay( - userValue - )} for ${key} is not valid. The default hotkey is used: ${formatHotkeyForDisplay( - defaultValue - )}` - ); - result[key] = defaultValue; - } - } - - return result; - } - - // Disables functions in the config object based on the provided list of function names - function disableFunctions(config, disabledFunctions) { - // Bind the hasOwnProperty method to the functionMap object to avoid errors - const hasOwnProperty = - Object.prototype.hasOwnProperty.bind(functionMap); - - // Loop through the disabledFunctions array and disable the corresponding functions in the config object - disabledFunctions.forEach(funcName => { - if (hasOwnProperty(funcName)) { - const key = functionMap[funcName]; - config[key] = "disable"; - } - }); - - // Return the updated config object - return config; - } - - /** - * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio. - * If the image display property is set to 'none', the mask breaks. To fix this, the function - * temporarily sets the display property to 'block' and then hides the mask again after 300 milliseconds - * to avoid breaking the canvas. Additionally, the function adjusts the mask to work correctly on - * very long images. 
- */ - function restoreImgRedMask(elements) { - const mainTabId = getTabId(elements); - - if (!mainTabId) return; - - const mainTab = gradioApp().querySelector(mainTabId); - const img = mainTab.querySelector("img"); - const imageARPreview = gradioApp().querySelector("#imageARPreview"); - - if (!img || !imageARPreview) return; - - imageARPreview.style.transform = ""; - if (parseFloat(mainTab.style.width) > 865) { - const transformString = mainTab.style.transform; - const scaleMatch = transformString.match( - /scale\(([-+]?[0-9]*\.?[0-9]+)\)/ - ); - let zoom = 1; // default zoom - - if (scaleMatch && scaleMatch[1]) { - zoom = Number(scaleMatch[1]); - } - - imageARPreview.style.transformOrigin = "0 0"; - imageARPreview.style.transform = `scale(${zoom})`; - } - - if (img.style.display !== "none") return; - - img.style.display = "block"; - - setTimeout(() => { - img.style.display = "none"; - }, 400); - } - - const hotkeysConfigOpts = await waitForOpts(); - - // Default config - const defaultHotkeysConfig = { - canvas_hotkey_zoom: "Alt", - canvas_hotkey_adjust: "Ctrl", - canvas_hotkey_reset: "KeyR", - canvas_hotkey_fullscreen: "KeyS", - canvas_hotkey_move: "KeyF", - canvas_hotkey_overlap: "KeyO", - canvas_hotkey_shrink_brush: "KeyQ", - canvas_hotkey_grow_brush: "KeyW", - canvas_disabled_functions: [], - canvas_show_tooltip: true, - canvas_auto_expand: true, - canvas_blur_prompt: false, - }; - - const functionMap = { - "Zoom": "canvas_hotkey_zoom", - "Adjust brush size": "canvas_hotkey_adjust", - "Hotkey shrink brush": "canvas_hotkey_shrink_brush", - "Hotkey enlarge brush": "canvas_hotkey_grow_brush", - "Moving canvas": "canvas_hotkey_move", - "Fullscreen": "canvas_hotkey_fullscreen", - "Reset Zoom": "canvas_hotkey_reset", - "Overlap": "canvas_hotkey_overlap" - }; - - // Loading the configuration from opts - const preHotkeysConfig = createHotkeyConfig( - defaultHotkeysConfig, - hotkeysConfigOpts - ); - - // Disable functions that are not needed by the user - const hotkeysConfig = disableFunctions( - preHotkeysConfig, - preHotkeysConfig.canvas_disabled_functions - ); - - let isMoving = false; - let mouseX, mouseY; - let activeElement; - - const elements = Object.fromEntries( - Object.keys(elementIDs).map(id => [ - id, - gradioApp().querySelector(elementIDs[id]) - ]) - ); - const elemData = {}; - - // Apply functionality to the range inputs. Restore redmask and correct for long images. - const rangeInputs = elements.rangeGroup ? 
- Array.from(elements.rangeGroup.querySelectorAll("input")) : - [ - gradioApp().querySelector("#img2img_width input[type='range']"), - gradioApp().querySelector("#img2img_height input[type='range']") - ]; - - for (const input of rangeInputs) { - input?.addEventListener("input", () => restoreImgRedMask(elements)); - } - - function applyZoomAndPan(elemId, isExtension = true) { - const targetElement = gradioApp().querySelector(elemId); - - if (!targetElement) { - console.log("Element not found"); - return; - } - - targetElement.style.transformOrigin = "0 0"; - - elemData[elemId] = { - zoom: 1, - panX: 0, - panY: 0 - }; - let fullScreenMode = false; - - // Create tooltip - function createTooltip() { - const toolTipElemnt = - targetElement.querySelector(".image-container"); - const tooltip = document.createElement("div"); - tooltip.className = "canvas-tooltip"; - - // Creating an item of information - const info = document.createElement("i"); - info.className = "canvas-tooltip-info"; - info.textContent = ""; - - // Create a container for the contents of the tooltip - const tooltipContent = document.createElement("div"); - tooltipContent.className = "canvas-tooltip-content"; - - // Define an array with hotkey information and their actions - const hotkeysInfo = [ - { - configKey: "canvas_hotkey_zoom", - action: "Zoom canvas", - keySuffix: " + wheel" - }, - { - configKey: "canvas_hotkey_adjust", - action: "Adjust brush size", - keySuffix: " + wheel" - }, - {configKey: "canvas_hotkey_reset", action: "Reset zoom"}, - { - configKey: "canvas_hotkey_fullscreen", - action: "Fullscreen mode" - }, - {configKey: "canvas_hotkey_move", action: "Move canvas"}, - {configKey: "canvas_hotkey_overlap", action: "Overlap"} - ]; - - // Create hotkeys array with disabled property based on the config values - const hotkeys = hotkeysInfo.map(info => { - const configValue = hotkeysConfig[info.configKey]; - const key = info.keySuffix ? - `${configValue}${info.keySuffix}` : - configValue.charAt(configValue.length - 1); - return { - key, - action: info.action, - disabled: configValue === "disable" - }; - }); - - for (const hotkey of hotkeys) { - if (hotkey.disabled) { - continue; - } - - const p = document.createElement("p"); - p.innerHTML = `${hotkey.key} - ${hotkey.action}`; - tooltipContent.appendChild(p); - } - - // Add information and content elements to the tooltip element - tooltip.appendChild(info); - tooltip.appendChild(tooltipContent); - - // Add a hint element to the target element - toolTipElemnt.appendChild(tooltip); - } - - //Show tool tip if setting enable - if (hotkeysConfig.canvas_show_tooltip) { - createTooltip(); - } - - // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui. 
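As a side note on the tooltip rows built above: they are derived directly from the hotkey configuration. Key-code entries such as "KeyR" are reduced to their final character for display, modifier entries get the " + wheel" suffix, and anything resolved to "disable" is omitted. A minimal standalone sketch of that mapping, using an illustrative helper name that is not taken from the original file:

function tooltipRow(configValue, action, keySuffix) {
    if (configValue === "disable") return null;            // disabled hotkeys never reach the tooltip
    const key = keySuffix
        ? `${configValue}${keySuffix}`                      // e.g. "Alt" + " + wheel" -> "Alt + wheel"
        : configValue.charAt(configValue.length - 1);       // e.g. "KeyR" -> "R"
    return `${key} - ${action}`;
}

tooltipRow("KeyR", "Reset zoom");               // "R - Reset zoom"
tooltipRow("Alt", "Zoom canvas", " + wheel");   // "Alt + wheel - Zoom canvas"
tooltipRow("disable", "Overlap");               // null, so no row is shown

The fixCanvas() helper that follows implements the img-tag workaround described in the comment above.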
- function fixCanvas() { - const activeTab = getActiveTab(elements).textContent.trim(); - - if (activeTab !== "img2img") { - const img = targetElement.querySelector(`${elemId} img`); - - if (img && img.style.display !== "none") { - img.style.display = "none"; - img.style.visibility = "hidden"; - } - } - } - - // Reset the zoom level and pan position of the target element to their initial values - function resetZoom() { - elemData[elemId] = { - zoomLevel: 1, - panX: 0, - panY: 0 - }; - - if (isExtension) { - targetElement.style.overflow = "hidden"; - } - - targetElement.isZoomed = false; - - fixCanvas(); - targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`; - - const canvas = gradioApp().querySelector( - `${elemId} canvas[key="interface"]` - ); - - toggleOverlap("off"); - fullScreenMode = false; - - const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']"); - if (closeBtn) { - closeBtn.addEventListener("click", resetZoom); - } - - if (canvas && isExtension) { - const parentElement = targetElement.closest('[id^="component-"]'); - if ( - canvas && - parseFloat(canvas.style.width) > parentElement.offsetWidth && - parseFloat(targetElement.style.width) > parentElement.offsetWidth - ) { - fitToElement(); - return; - } - - } - - if ( - canvas && - !isExtension && - parseFloat(canvas.style.width) > 865 && - parseFloat(targetElement.style.width) > 865 - ) { - fitToElement(); - return; - } - - targetElement.style.width = ""; - } - - // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements - function toggleOverlap(forced = "") { - const zIndex1 = "0"; - const zIndex2 = "998"; - - targetElement.style.zIndex = - targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1; - - if (forced === "off") { - targetElement.style.zIndex = zIndex1; - } else if (forced === "on") { - targetElement.style.zIndex = zIndex2; - } - } - - // Adjust the brush size based on the deltaY value from a mouse wheel event - function adjustBrushSize( - elemId, - deltaY, - withoutValue = false, - percentage = 5 - ) { - const input = - gradioApp().querySelector( - `${elemId} input[aria-label='Brush radius']` - ) || - gradioApp().querySelector( - `${elemId} button[aria-label="Use brush"]` - ); - - if (input) { - input.click(); - if (!withoutValue) { - const maxValue = - parseFloat(input.getAttribute("max")) || 100; - const changeAmount = maxValue * (percentage / 100); - const newValue = - parseFloat(input.value) + - (deltaY > 0 ? 
-changeAmount : changeAmount); - input.value = Math.min(Math.max(newValue, 0), maxValue); - input.dispatchEvent(new Event("change")); - } - } - } - - // Reset zoom when uploading a new image - const fileInput = gradioApp().querySelector( - `${elemId} input[type="file"][accept="image/*"].svelte-116rqfv` - ); - fileInput.addEventListener("click", resetZoom); - - // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables - function updateZoom(newZoomLevel, mouseX, mouseY) { - newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15)); - - elemData[elemId].panX += - mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel; - elemData[elemId].panY += - mouseY - (mouseY * newZoomLevel) / elemData[elemId].zoomLevel; - - targetElement.style.transformOrigin = "0 0"; - targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`; - - toggleOverlap("on"); - if (isExtension) { - targetElement.style.overflow = "visible"; - } - - return newZoomLevel; - } - - // Change the zoom level based on user interaction - function changeZoomLevel(operation, e) { - if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) { - e.preventDefault(); - - let zoomPosX, zoomPosY; - let delta = 0.2; - if (elemData[elemId].zoomLevel > 7) { - delta = 0.9; - } else if (elemData[elemId].zoomLevel > 2) { - delta = 0.6; - } - - zoomPosX = e.clientX; - zoomPosY = e.clientY; - - fullScreenMode = false; - elemData[elemId].zoomLevel = updateZoom( - elemData[elemId].zoomLevel + - (operation === "+" ? delta : -delta), - zoomPosX - targetElement.getBoundingClientRect().left, - zoomPosY - targetElement.getBoundingClientRect().top - ); - - targetElement.isZoomed = true; - } - } - - /** - * This function fits the target element to the screen by calculating - * the required scale and offsets. It also updates the global variables - * zoomLevel, panX, and panY to reflect the new state. 
- */ - - function fitToElement() { - //Reset Zoom - targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; - - let parentElement; - - if (isExtension) { - parentElement = targetElement.closest('[id^="component-"]'); - } else { - parentElement = targetElement.parentElement; - } - - - // Get element and screen dimensions - const elementWidth = targetElement.offsetWidth; - const elementHeight = targetElement.offsetHeight; - - const screenWidth = parentElement.clientWidth; - const screenHeight = parentElement.clientHeight; - - // Get element's coordinates relative to the parent element - const elementRect = targetElement.getBoundingClientRect(); - const parentRect = parentElement.getBoundingClientRect(); - const elementX = elementRect.x - parentRect.x; - - // Calculate scale and offsets - const scaleX = screenWidth / elementWidth; - const scaleY = screenHeight / elementHeight; - const scale = Math.min(scaleX, scaleY); - - const transformOrigin = - window.getComputedStyle(targetElement).transformOrigin; - const [originX, originY] = transformOrigin.split(" "); - const originXValue = parseFloat(originX); - const originYValue = parseFloat(originY); - - const offsetX = - (screenWidth - elementWidth * scale) / 2 - - originXValue * (1 - scale); - const offsetY = - (screenHeight - elementHeight * scale) / 2.5 - - originYValue * (1 - scale); - - // Apply scale and offsets to the element - targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; - - // Update global variables - elemData[elemId].zoomLevel = scale; - elemData[elemId].panX = offsetX; - elemData[elemId].panY = offsetY; - - fullScreenMode = false; - toggleOverlap("off"); - } - - /** - * This function fits the target element to the screen by calculating - * the required scale and offsets. It also updates the global variables - * zoomLevel, panX, and panY to reflect the new state. 
- */ - - // Fullscreen mode - function fitToScreen() { - const canvas = gradioApp().querySelector( - `${elemId} canvas[key="interface"]` - ); - - if (!canvas) return; - - if (canvas.offsetWidth > 862 || isExtension) { - targetElement.style.width = (canvas.offsetWidth + 2) + "px"; - } - - if (isExtension) { - targetElement.style.overflow = "visible"; - } - - if (fullScreenMode) { - resetZoom(); - fullScreenMode = false; - return; - } - - //Reset Zoom - targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; - - // Get scrollbar width to right-align the image - const scrollbarWidth = - window.innerWidth - document.documentElement.clientWidth; - - // Get element and screen dimensions - const elementWidth = targetElement.offsetWidth; - const elementHeight = targetElement.offsetHeight; - const screenWidth = window.innerWidth - scrollbarWidth; - const screenHeight = window.innerHeight; - - // Get element's coordinates relative to the page - const elementRect = targetElement.getBoundingClientRect(); - const elementY = elementRect.y; - const elementX = elementRect.x; - - // Calculate scale and offsets - const scaleX = screenWidth / elementWidth; - const scaleY = screenHeight / elementHeight; - const scale = Math.min(scaleX, scaleY); - - // Get the current transformOrigin - const computedStyle = window.getComputedStyle(targetElement); - const transformOrigin = computedStyle.transformOrigin; - const [originX, originY] = transformOrigin.split(" "); - const originXValue = parseFloat(originX); - const originYValue = parseFloat(originY); - - // Calculate offsets with respect to the transformOrigin - const offsetX = - (screenWidth - elementWidth * scale) / 2 - - elementX - - originXValue * (1 - scale); - const offsetY = - (screenHeight - elementHeight * scale) / 2 - - elementY - - originYValue * (1 - scale); - - // Apply scale and offsets to the element - targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; - - // Update global variables - elemData[elemId].zoomLevel = scale; - elemData[elemId].panX = offsetX; - elemData[elemId].panY = offsetY; - - fullScreenMode = true; - toggleOverlap("on"); - } - - // Handle keydown events - function handleKeyDown(event) { - // Disable key locks to make pasting from the buffer work correctly - if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") { - return; - } - - // before activating shortcut, ensure user is not actively typing in an input field - if (!hotkeysConfig.canvas_blur_prompt) { - if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') { - return; - } - } - - - const hotkeyActions = { - [hotkeysConfig.canvas_hotkey_reset]: resetZoom, - [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen, - [hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10), - [hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10) - }; - - const action = hotkeyActions[event.code]; - if (action) { - event.preventDefault(); - action(event); - } - - if ( - isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) || - isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust) - ) { - event.preventDefault(); - } - } - - // Get Mouse position - function getMousePosition(e) { - mouseX = e.offsetX; - mouseY = e.offsetY; - } - - // Simulation of the function to put a long image into the screen. 
- // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element. - // We hide the image and show it to the user when it is ready. - - targetElement.isExpanded = false; - function autoExpand() { - const canvas = document.querySelector(`${elemId} canvas[key="interface"]`); - if (canvas) { - if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) { - targetElement.style.visibility = "hidden"; - setTimeout(() => { - fitToScreen(); - resetZoom(); - targetElement.style.visibility = "visible"; - targetElement.isExpanded = true; - }, 10); - } - } - } - - targetElement.addEventListener("mousemove", getMousePosition); - - //observers - // Creating an observer with a callback function to handle DOM changes - const observer = new MutationObserver((mutationsList, observer) => { - for (let mutation of mutationsList) { - // If the style attribute of the canvas has changed, by observation it happens only when the picture changes - if (mutation.type === 'attributes' && mutation.attributeName === 'style' && - mutation.target.tagName.toLowerCase() === 'canvas') { - targetElement.isExpanded = false; - setTimeout(resetZoom, 10); - } - } - }); - - // Apply auto expand if enabled - if (hotkeysConfig.canvas_auto_expand) { - targetElement.addEventListener("mousemove", autoExpand); - // Set up an observer to track attribute changes - observer.observe(targetElement, {attributes: true, childList: true, subtree: true}); - } - - // Handle events only inside the targetElement - let isKeyDownHandlerAttached = false; - - function handleMouseMove() { - if (!isKeyDownHandlerAttached) { - document.addEventListener("keydown", handleKeyDown); - isKeyDownHandlerAttached = true; - - activeElement = elemId; - } - } - - function handleMouseLeave() { - if (isKeyDownHandlerAttached) { - document.removeEventListener("keydown", handleKeyDown); - isKeyDownHandlerAttached = false; - - activeElement = null; - } - } - - // Add mouse event handlers - targetElement.addEventListener("mousemove", handleMouseMove); - targetElement.addEventListener("mouseleave", handleMouseLeave); - - // Reset zoom when click on another tab - elements.img2imgTabs.addEventListener("click", resetZoom); - elements.img2imgTabs.addEventListener("click", () => { - // targetElement.style.width = ""; - if (parseInt(targetElement.style.width) > 865) { - setTimeout(fitToElement, 0); - } - }); - - targetElement.addEventListener("wheel", e => { - // change zoom level - const operation = e.deltaY > 0 ? "-" : "+"; - changeZoomLevel(operation, e); - - // Handle brush size adjustment with ctrl key pressed - if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) { - e.preventDefault(); - - // Increase or decrease brush size based on scroll direction - adjustBrushSize(elemId, e.deltaY); - } - }); - - // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element. 
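A brief aside on the auto-expand path defined above, before the pan-move handler that the preceding comment introduces: the flow is essentially hide, resize, reveal. When a long image produces a horizontal scrollbar, the element is hidden, fitToScreen() and resetZoom() are run, and the element is made visible again; the MutationObserver clears isExpanded whenever the inner canvas's style changes (i.e. when a new picture is loaded), so the next image is expanded as well. A condensed sketch of that flow, with parameter names standing in for the functions defined earlier in this file:

function autoExpandOnce(el, hasHorizontalScrollbar, fitToScreen, resetZoom) {
    if (el.isExpanded || !hasHorizontalScrollbar(el)) return;
    el.style.visibility = "hidden";      // hide the element while the transform is recomputed
    setTimeout(() => {
        fitToScreen();                   // go fullscreen first so the long image is fully measured
        resetZoom();                     // then shrink it back into the component
        el.style.visibility = "visible";
        el.isExpanded = true;            // skip further expands until a new image clears the flag
    }, 10);
}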
- function handleMoveKeyDown(e) { - - // Disable key locks to make pasting from the buffer work correctly - if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") { - return; - } - - // before activating shortcut, ensure user is not actively typing in an input field - if (!hotkeysConfig.canvas_blur_prompt) { - if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') { - return; - } - } - - - if (e.code === hotkeysConfig.canvas_hotkey_move) { - if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { - e.preventDefault(); - document.activeElement.blur(); - isMoving = true; - } - } - } - - function handleMoveKeyUp(e) { - if (e.code === hotkeysConfig.canvas_hotkey_move) { - isMoving = false; - } - } - - document.addEventListener("keydown", handleMoveKeyDown); - document.addEventListener("keyup", handleMoveKeyUp); - - // Detect zoom level and update the pan speed. - function updatePanPosition(movementX, movementY) { - let panSpeed = 2; - - if (elemData[elemId].zoomLevel > 8) { - panSpeed = 3.5; - } - - elemData[elemId].panX += movementX * panSpeed; - elemData[elemId].panY += movementY * panSpeed; - - // Delayed redraw of an element - requestAnimationFrame(() => { - targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${elemData[elemId].zoomLevel})`; - toggleOverlap("on"); - }); - } - - function handleMoveByKey(e) { - if (isMoving && elemId === activeElement) { - updatePanPosition(e.movementX, e.movementY); - targetElement.style.pointerEvents = "none"; - - if (isExtension) { - targetElement.style.overflow = "visible"; - } - - } else { - targetElement.style.pointerEvents = "auto"; - } - } - - // Prevents sticking to the mouse - window.onblur = function() { - isMoving = false; - }; - - // Checks for extension - function checkForOutBox() { - const parentElement = targetElement.closest('[id^="component-"]'); - if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) { - resetZoom(); - targetElement.isExpanded = true; - } - - if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) { - resetZoom(); - } - - if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) { - resetZoom(); - } - } - - if (isExtension) { - targetElement.addEventListener("mousemove", checkForOutBox); - } - - - window.addEventListener('resize', (e) => { - resetZoom(); - - if (isExtension) { - targetElement.isExpanded = false; - targetElement.isZoomed = false; - } - }); - - gradioApp().addEventListener("mousemove", handleMoveByKey); - - - } - - applyZoomAndPan(elementIDs.sketch, false); - applyZoomAndPan(elementIDs.inpaint, false); - applyZoomAndPan(elementIDs.inpaintSketch, false); - - // Make the function global so that other extensions can take advantage of this solution - const applyZoomAndPanIntegration = async(id, elementIDs) => { - const mainEl = document.querySelector(id); - if (id.toLocaleLowerCase() === "none") { - for (const elementID of elementIDs) { - const el = await waitForElement(elementID); - if (!el) break; - applyZoomAndPan(elementID); - } - return; - } - - if (!mainEl) return; - mainEl.addEventListener("click", async() => { - for (const elementID of elementIDs) { - const el = await waitForElement(elementID); - if (!el) break; - applyZoomAndPan(elementID); - } - }, {once: true}); - }; - - 
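One detail of this handler worth spelling out is the cursor-anchored zoom in updateZoom() above: the pan offsets are shifted by mouse - mouse * newZoom / oldZoom, which keeps the content point under the cursor at the same screen position while the scale changes. A standalone sketch of that update, with hypothetical names:

// Pan/zoom update that keeps the content under the cursor fixed on screen;
// mouseOffset is the cursor position relative to the element's bounding box.
function anchoredZoom(pan, oldZoom, newZoom, mouseOffset) {
    return {
        pan: pan + mouseOffset - (mouseOffset * newZoom) / oldZoom,
        zoom: newZoom,
    };
}

// Why it works: the content point under the cursor is c = mouseOffset / oldZoom.
// Its new screen position is pan' + newZoom * c
//   = pan + mouseOffset - mouseOffset * newZoom / oldZoom + newZoom * mouseOffset / oldZoom
//   = pan + mouseOffset, i.e. unchanged, so the image zooms "around" the cursor.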
window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image") - - window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension - - /* - The function `applyZoomAndPanIntegration` takes two arguments: - - 1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click. - If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event. - - 2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument. - If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event. - - Example usage: - applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); - In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet". - */ - - // More examples - // Add integration with ControlNet txt2img One TAB - // applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); - - // Add integration with ControlNet txt2img Tabs - // applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`)); - - // Add integration with Inpaint Anything - // applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]); -}); diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py deleted file mode 100644 index 89b7c31f2..000000000 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ /dev/null @@ -1,17 +0,0 @@ -import gradio as gr -from modules import shared - -shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { - "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), - "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), - "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), - "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), - "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), - "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), - "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"), - "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), - "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually 
pressing the S and R buttons"), - "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"), - "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), -})) diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css deleted file mode 100644 index 5d8054e65..000000000 --- a/extensions-builtin/canvas-zoom-and-pan/style.css +++ /dev/null @@ -1,66 +0,0 @@ -.canvas-tooltip-info { - position: absolute; - top: 10px; - left: 10px; - cursor: help; - background-color: rgba(0, 0, 0, 0.3); - width: 20px; - height: 20px; - border-radius: 50%; - display: flex; - align-items: center; - justify-content: center; - flex-direction: column; - - z-index: 100; -} - -.canvas-tooltip-info::after { - content: ''; - display: block; - width: 2px; - height: 7px; - background-color: white; - margin-top: 2px; -} - -.canvas-tooltip-info::before { - content: ''; - display: block; - width: 2px; - height: 2px; - background-color: white; -} - -.canvas-tooltip-content { - display: none; - background-color: #f9f9f9; - color: #333; - border: 1px solid #ddd; - padding: 15px; - position: absolute; - top: 40px; - left: 10px; - width: 250px; - font-size: 16px; - opacity: 0; - border-radius: 8px; - box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); - - z-index: 100; -} - -.canvas-tooltip:hover .canvas-tooltip-content { - display: block; - animation: fadeIn 0.5s; - opacity: 1; -} - -@keyframes fadeIn { - from {opacity: 0;} - to {opacity: 1;} -} - -.styler { - overflow:inherit !important; -} \ No newline at end of file diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 4c10d9c7d..a91bea4fa 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, infotext_utils +from modules import scripts, shared, ui_components, ui_settings, infotext_utils, errors from modules.ui_components import FormColumn @@ -42,7 +42,11 @@ def ui(self, is_img2img): setting_name = extra_options[index] with FormColumn(): - comp = ui_settings.create_setting_component(setting_name) + try: + comp = ui_settings.create_setting_component(setting_name) + except KeyError: + errors.report(f"Can't add extra options for {setting_name} in ui") + continue self.comps.append(comp) self.setting_names.append(setting_name) diff --git a/extensions-builtin/forge_legacy_preprocessors/install.py b/extensions-builtin/forge_legacy_preprocessors/install.py index 3a9bd1172..e20854acd 100644 --- a/extensions-builtin/forge_legacy_preprocessors/install.py +++ b/extensions-builtin/forge_legacy_preprocessors/install.py @@ -13,7 +13,7 @@ def comparable_version(version: str) -> Tuple: - return tuple(version.split(".")) + return tuple(map(int, version.split("."))) def get_installed_version(package: str) -> Optional[str]: diff --git a/scripts/processing_autosized_crop.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_autosized_crop.py similarity index 97% rename from scripts/processing_autosized_crop.py rename to 
extensions-builtin/postprocessing-for-training/scripts/postprocessing_autosized_crop.py index 7e6749898..1e83de607 100644 --- a/scripts/processing_autosized_crop.py +++ b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_autosized_crop.py @@ -1,64 +1,64 @@ -from PIL import Image - -from modules import scripts_postprocessing, ui_components -import gradio as gr - - -def center_crop(image: Image, w: int, h: int): - iw, ih = image.size - if ih / h < iw / w: - sw = w * ih / h - box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih - else: - sh = h * iw / w - box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 - return image.resize((w, h), Image.Resampling.LANCZOS, box) - - -def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): - iw, ih = image.size - err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h)) - wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64) - if minarea <= w * h <= maxarea and err(w, h) <= threshold), - key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1], - default=None - ) - return wh and center_crop(image, *wh) - - -class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): - name = "Auto-sized crop" - order = 4020 - - def ui(self): - with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: - gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') - with gr.Row(): - mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim") - maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim") - with gr.Row(): - minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea") - maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea") - with gr.Row(): - objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective") - threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold") - - return { - "enable": enable, - "mindim": mindim, - "maxdim": maxdim, - "minarea": minarea, - "maxarea": maxarea, - "objective": objective, - "threshold": threshold, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold): - if not enable: - return - - cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold) - if cropped is not None: - pp.image = cropped - else: - print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)") +from PIL import Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + + +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, 
h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h)) + wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1], + default=None + ) + return wh and center_crop(image, *wh) + + +class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto-sized crop" + order = 4020 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim") + maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim") + with gr.Row(): + minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea") + maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea") + with gr.Row(): + objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective") + threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold") + + return { + "enable": enable, + "mindim": mindim, + "maxdim": maxdim, + "minarea": minarea, + "maxarea": maxarea, + "objective": objective, + "threshold": threshold, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold): + if not enable: + return + + cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold) + if cropped is not None: + pp.image = cropped + else: + print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)") diff --git a/scripts/postprocessing_caption.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py similarity index 96% rename from scripts/postprocessing_caption.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py index 5592a8987..758222ae2 100644 --- a/scripts/postprocessing_caption.py +++ b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py @@ -1,30 +1,30 @@ -from modules import scripts_postprocessing, ui_components, deepbooru, shared -import gradio as gr - - -class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): - name = "Caption" - order = 4040 - - def ui(self): - with ui_components.InputAccordion(False, label="Caption") as enable: - option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False) - - return { - "enable": enable, - "option": option, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): - if not enable: - return - - captions = [pp.caption] - - if "Deepbooru" in option: - captions.append(deepbooru.model.tag(pp.image)) - - if "BLIP" in option: - captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) - - pp.caption = ", ".join([x for x in captions if x]) +from modules import scripts_postprocessing, 
ui_components, deepbooru, shared +import gradio as gr + + +class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): + name = "Caption" + order = 4040 + + def ui(self): + with ui_components.InputAccordion(False, label="Caption") as enable: + option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + captions = [pp.caption] + + if "Deepbooru" in option: + captions.append(deepbooru.model.tag(pp.image)) + + if "BLIP" in option: + captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) + + pp.caption = ", ".join([x for x in captions if x]) diff --git a/scripts/postprocessing_create_flipped_copies.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py similarity index 97% rename from scripts/postprocessing_create_flipped_copies.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py index b673003b6..e7bd34038 100644 --- a/scripts/postprocessing_create_flipped_copies.py +++ b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py @@ -1,32 +1,32 @@ -from PIL import ImageOps, Image - -from modules import scripts_postprocessing, ui_components -import gradio as gr - - -class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): - name = "Create flipped copies" - order = 4030 - - def ui(self): - with ui_components.InputAccordion(False, label="Create flipped copies") as enable: - with gr.Row(): - option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False) - - return { - "enable": enable, - "option": option, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): - if not enable: - return - - if "Horizontal" in option: - pp.extra_images.append(ImageOps.mirror(pp.image)) - - if "Vertical" in option: - pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)) - - if "Both" in option: - pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT)) +from PIL import ImageOps, Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): + name = "Create flipped copies" + order = 4030 + + def ui(self): + with ui_components.InputAccordion(False, label="Create flipped copies") as enable: + with gr.Row(): + option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + if "Horizontal" in option: + pp.extra_images.append(ImageOps.mirror(pp.image)) + + if "Vertical" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)) + + if "Both" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT)) diff --git a/scripts/postprocessing_focal_crop.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py similarity index 97% rename from scripts/postprocessing_focal_crop.py rename to 
extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py index cff1dbc54..08fd2ccfb 100644 --- a/scripts/postprocessing_focal_crop.py +++ b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py @@ -1,54 +1,54 @@ - -from modules import scripts_postprocessing, ui_components, errors -import gradio as gr - -from modules.textual_inversion import autocrop - - -class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): - name = "Auto focal point crop" - order = 4010 - - def ui(self): - with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: - face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") - entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") - edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") - debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - return { - "enable": enable, - "face_weight": face_weight, - "entropy_weight": entropy_weight, - "edges_weight": edges_weight, - "debug": debug, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): - if not enable: - return - - if not pp.shared.target_width or not pp.shared.target_height: - return - - dnn_model_path = None - try: - dnn_model_path = autocrop.download_and_cache_models() - except Exception: - errors.report("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", exc_info=True) - - autocrop_settings = autocrop.Settings( - crop_width=pp.shared.target_width, - crop_height=pp.shared.target_height, - face_points_weight=face_weight, - entropy_points_weight=entropy_weight, - corner_points_weight=edges_weight, - annotate_image=debug, - dnn_model_path=dnn_model_path, - ) - - result, *others = autocrop.crop_image(pp.image, autocrop_settings) - - pp.image = result - pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] - + +from modules import scripts_postprocessing, ui_components, errors +import gradio as gr + +from modules.textual_inversion import autocrop + + +class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto focal point crop" + order = 4010 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: + face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") + entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") + edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") + debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + return { + "enable": enable, + "face_weight": face_weight, + "entropy_weight": entropy_weight, + "edges_weight": edges_weight, + "debug": debug, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): + if not enable: + return + + if not pp.shared.target_width or not pp.shared.target_height: + return + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models() + except Exception: + errors.report("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", exc_info=True) + + autocrop_settings = autocrop.Settings( + crop_width=pp.shared.target_width, + crop_height=pp.shared.target_height, + face_points_weight=face_weight, + entropy_points_weight=entropy_weight, + corner_points_weight=edges_weight, + annotate_image=debug, + dnn_model_path=dnn_model_path, + ) + + result, *others = autocrop.crop_image(pp.image, autocrop_settings) + + pp.image = result + pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] + diff --git a/scripts/postprocessing_split_oversized.py b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py similarity index 97% rename from scripts/postprocessing_split_oversized.py rename to extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py index 133e199b8..888740e34 100644 --- a/scripts/postprocessing_split_oversized.py +++ b/extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py @@ -1,71 +1,71 @@ -import math - -from modules import scripts_postprocessing, ui_components -import gradio as gr - - -def split_pic(image, inverse_xy, width, height, overlap_ratio): - if inverse_xy: - from_w, from_h = image.height, image.width - to_w, to_h = height, width - else: - from_w, from_h = image.width, image.height - to_w, to_h = width, height - h = from_h * to_w // from_w - if inverse_xy: - image = image.resize((h, to_w)) - else: - image = image.resize((to_w, h)) - - split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) - y_step = (h - to_h) / (split_count - 1) - for i in range(split_count): - y = int(y_step * i) - if inverse_xy: - splitted = image.crop((y, 0, y + to_h, to_w)) - else: - splitted = image.crop((0, y, to_w, y + to_h)) - yield splitted - - -class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing): - name = "Split oversized images" - order = 4000 - - def ui(self): - with ui_components.InputAccordion(False, label="Split oversized images") as enable: - with gr.Row(): - split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold") - overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio") - - return { - "enable": enable, - "split_threshold": split_threshold, - "overlap_ratio": overlap_ratio, - } - - def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio): - if not enable: - return - - width = pp.shared.target_width - height = pp.shared.target_height - - if not width or not height: - return - - if pp.image.height > pp.image.width: - ratio = (pp.image.width * height) / (pp.image.height * width) - inverse_xy = False - else: - ratio = (pp.image.height * width) / (pp.image.width * height) - inverse_xy = True - - if ratio >= 1.0 or ratio > split_threshold: - return - - result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio) - - pp.image = result - pp.extra_images = [pp.create_copy(x) for x in others] - +import math + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + 
if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + + +class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing): + name = "Split oversized images" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Split oversized images") as enable: + with gr.Row(): + split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold") + overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio") + + return { + "enable": enable, + "split_threshold": split_threshold, + "overlap_ratio": overlap_ratio, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio): + if not enable: + return + + width = pp.shared.target_width + height = pp.shared.target_height + + if not width or not height: + return + + if pp.image.height > pp.image.width: + ratio = (pp.image.width * height) / (pp.image.height * width) + inverse_xy = False + else: + ratio = (pp.image.height * width) / (pp.image.width * height) + inverse_xy = True + + if ratio >= 1.0 or ratio > split_threshold: + return + + result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio) + + pp.image = result + pp.extra_images = [pp.create_copy(x) for x in others] + diff --git a/extensions-builtin/sd_forge_controlnet/install.py b/extensions-builtin/sd_forge_controlnet/install.py index 5370d2213..99d3f0f92 100644 --- a/extensions-builtin/sd_forge_controlnet/install.py +++ b/extensions-builtin/sd_forge_controlnet/install.py @@ -13,7 +13,7 @@ def comparable_version(version: str) -> Tuple: - return tuple(version.split(".")) + return tuple(map(int, version.split("."))) def get_installed_version(package: str) -> Optional[str]: diff --git a/extensions-builtin/sd_forge_controlnet/javascript/active_units.js b/extensions-builtin/sd_forge_controlnet/javascript/active_units.js index a3ba0fc3a..72c7ca95d 100644 --- a/extensions-builtin/sd_forge_controlnet/javascript/active_units.js +++ b/extensions-builtin/sd_forge_controlnet/javascript/active_units.js @@ -95,7 +95,6 @@ this.attachImageUploadListener(); this.attachImageStateChangeObserver(); this.attachA1111SendInfoObserver(); - this.attachPresetDropdownObserver(); this.attachAccordionStateObserver(); } @@ -119,7 +118,7 @@ */ getUnitHeaderTextElement() { return this.tab.querySelector( - `:nth-child(${this.tabIndex + 1}) span.svelte-s1r2yt` + `button > span:nth-child(1)` ); } @@ -192,16 +191,17 @@ unitHeader.appendChild(span); } getInputImageSrc() { - const img = this.inputImageGroup.querySelector('.cnet-image img'); - return img ? img.src : null; + const img = this.inputImageGroup.querySelector('.cnet-image .forge-image'); + return (img && img.src.startsWith('data')) ? img.src : null; } getPreprocessorPreviewImageSrc() { - const img = this.generatedImageGroup.querySelector('.cnet-image img'); - return img ? img.src : null; + const img = this.generatedImageGroup.querySelector('.cnet-image .forge-image'); + return (img && img.src.startsWith('data')) ? 
img.src : null; } getMaskImageSrc() { function isEmptyCanvas(canvas) { if (!canvas) return true; + if (canvas.width == 0 || canvas.height ==0) return true; const ctx = canvas.getContext('2d'); // Get the image data const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); @@ -216,14 +216,14 @@ } return isPureBlack; } - const maskImg = this.maskImageGroup.querySelector('.cnet-mask-image img'); + const maskImg = this.maskImageGroup.querySelector('.cnet-mask-image .forge-image'); // Hand-drawn mask on mask upload. - const handDrawnMaskCanvas = this.maskImageGroup.querySelector('.cnet-mask-image canvas[key="mask"]'); + const handDrawnMaskCanvas = this.maskImageGroup.querySelector('.cnet-mask-image .forge-drawing-canvas'); // Hand-drawn mask on input image upload. - const inputImageHandDrawnMaskCanvas = this.inputImageGroup.querySelector('.cnet-image canvas[key="mask"]'); + const inputImageHandDrawnMaskCanvas = this.inputImageGroup.querySelector('.cnet-image .forge-drawing-canvas'); if (!isEmptyCanvas(handDrawnMaskCanvas)) { return handDrawnMaskCanvas.toDataURL(); - } else if (maskImg) { + } else if (maskImg && maskImg.src.startsWith('data')) { return maskImg.src; } else if (!isEmptyCanvas(inputImageHandDrawnMaskCanvas)) { return inputImageHandDrawnMaskCanvas.toDataURL(); @@ -347,25 +347,6 @@ } } - attachPresetDropdownObserver() { - const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown'); - - new MutationObserver((mutationsList) => { - for (const mutation of mutationsList) { - if (mutation.removedNodes.length > 0) { - setTimeout(() => { - this.updateActiveState(); - this.updateActiveUnitCount(); - this.updateActiveControlType(); - }, 1000); - return; - } - } - }).observe(presetDropDown, { - childList: true, - subtree: true, - }); - } /** * Observer that triggers when the ControlNetUnit's accordion(tab) closes. */ diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index b3e16df68..74f62b8d8 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -13,43 +13,44 @@ ) from lib_controlnet.logging import logger from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor -from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI -from lib_controlnet.controlnet_ui.tool_button import ToolButton from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.enums import InputMode, HiResFixOption from modules import shared, script_callbacks from modules.ui_components import FormRow from modules_forge.forge_util import HWC3 from lib_controlnet.external_code import UiControlNetUnit +from modules.ui_components import ToolButton +from gradio_rangeslider import RangeSlider +from modules_forge.forge_canvas.canvas import ForgeCanvas @dataclass class A1111Context: """Contains all components from A1111.""" - img2img_batch_input_dir: Optional[gr.components.IOComponent] = None - img2img_batch_output_dir: Optional[gr.components.IOComponent] = None - txt2img_submit_button: Optional[gr.components.IOComponent] = None - img2img_submit_button: Optional[gr.components.IOComponent] = None + img2img_batch_input_dir = None + img2img_batch_output_dir = None + txt2img_submit_button = None + img2img_submit_button = None # Slider controls from A1111 WebUI. 
- txt2img_w_slider: Optional[gr.components.IOComponent] = None - txt2img_h_slider: Optional[gr.components.IOComponent] = None - img2img_w_slider: Optional[gr.components.IOComponent] = None - img2img_h_slider: Optional[gr.components.IOComponent] = None + txt2img_w_slider = None + txt2img_h_slider = None + img2img_w_slider = None + img2img_h_slider = None - img2img_img2img_tab: Optional[gr.components.IOComponent] = None - img2img_img2img_sketch_tab: Optional[gr.components.IOComponent] = None - img2img_batch_tab: Optional[gr.components.IOComponent] = None - img2img_inpaint_tab: Optional[gr.components.IOComponent] = None - img2img_inpaint_sketch_tab: Optional[gr.components.IOComponent] = None - img2img_inpaint_upload_tab: Optional[gr.components.IOComponent] = None + img2img_img2img_tab = None + img2img_img2img_sketch_tab = None + img2img_batch_tab = None + img2img_inpaint_tab = None + img2img_inpaint_sketch_tab = None + img2img_inpaint_upload_tab = None - img2img_inpaint_area: Optional[gr.components.IOComponent] = None - txt2img_enable_hr: Optional[gr.components.IOComponent] = None + img2img_inpaint_area = None + txt2img_enable_hr = None @property - def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]: + def img2img_inpaint_tabs(self): return ( self.img2img_inpaint_tab, self.img2img_inpaint_sketch_tab, @@ -57,7 +58,7 @@ def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]: ) @property - def img2img_non_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]: + def img2img_non_inpaint_tabs(self): return ( self.img2img_img2img_tab, self.img2img_img2img_sketch_tab, @@ -81,7 +82,7 @@ def ui_initialized(self) -> bool: if name not in optional_components.values() ) - def set_component(self, component: gr.components.IOComponent): + def set_component(self, component): id_mapping = { "img2img_batch_input_dir": "img2img_batch_input_dir", "img2img_batch_output_dir": "img2img_batch_output_dir", @@ -187,16 +188,13 @@ def __init__( self.batch_image_dir = None self.merge_tab = None self.batch_input_gallery = None - self.merge_upload_button = None - self.merge_clear_button = None + self.batch_mask_gallery = None self.create_canvas = None self.canvas_width = None self.canvas_height = None self.canvas_create_button = None self.canvas_cancel_button = None self.open_new_canvas_button = None - self.webcam_enable = None - self.webcam_mirror = None self.send_dimen_button = None self.pixel_perfect = None self.preprocessor_preview = None @@ -207,6 +205,7 @@ def __init__( self.model = None self.refresh_models = None self.weight = None + self.timestep_range = None self.guidance_start = None self.guidance_end = None self.advanced = None @@ -217,7 +216,6 @@ def __init__( self.resize_mode = None self.use_preview_as_input = None self.openpose_editor = None - self.preset_panel = None self.upload_independent_img_in_img2img = None self.image_upload_panel = None self.save_detected_map = None @@ -249,43 +247,34 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: with gr.Group(visible=not self.is_img2img) as self.image_upload_panel: self.save_detected_map = gr.Checkbox(value=True, visible=False) - with gr.Tabs(): + + with gr.Tabs(visible=True): with gr.Tab(label="Single Image") as self.upload_tab: with gr.Row(elem_classes=["cnet-image-row"], equal_height=True): with gr.Group(elem_classes=["cnet-input-image-group"]): - self.image = gr.Image( - source="upload", - brush_radius=20, - mirror_webcam=False, - type="numpy", - tool="sketch", + self.image = ForgeCanvas( 
elem_id=f"{elem_id_tabname}_{tabname}_input_image", elem_classes=["cnet-image"], - brush_color=shared.opts.img2img_inpaint_mask_brush_color - if hasattr( - shared.opts, "img2img_inpaint_mask_brush_color" - ) - else None, - ) - self.image.preprocess = functools.partial( - svg_preprocess, preprocess=self.image.preprocess + contrast_scribbles=True, + height=300, + numpy=True ) self.openpose_editor.render_upload() with gr.Group( - visible=False, elem_classes=["cnet-generated-image-group"] + visible=False, elem_classes=["cnet-generated-image-group"] ) as self.generated_image_group: - self.generated_image = gr.Image( - value=None, - label="Preprocessor Preview", + self.generated_image = ForgeCanvas( elem_id=f"{elem_id_tabname}_{tabname}_generated_image", elem_classes=["cnet-image"], - interactive=True, - height=242, - ) # Gradio's magic number. Only 242 works. + height=300, + no_scribbles=True, + no_upload=True, + numpy=True + ) with gr.Group( - elem_classes=["cnet-generated-image-control-group"] + elem_classes=["cnet-generated-image-control-group"] ): if self.photopea: self.photopea.render_child_trigger() @@ -299,22 +288,18 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) with gr.Group( - visible=False, elem_classes=["cnet-mask-image-group"] + visible=False, elem_classes=["cnet-mask-image-group"] ) as self.mask_image_group: - self.mask_image = gr.Image( - value=None, - label="Mask", + self.mask_image = ForgeCanvas( elem_id=f"{elem_id_tabname}_{tabname}_mask_image", elem_classes=["cnet-mask-image"], - interactive=True, - brush_radius=20, - type="numpy", - tool="sketch", - brush_color=shared.opts.img2img_inpaint_mask_brush_color - if hasattr( - shared.opts, "img2img_inpaint_mask_brush_color" - ) - else None, + height=300, + scribble_color='#FFFFFF', + scribble_width=1, + scribble_alpha_fixed=True, + scribble_color_fixed=True, + scribble_softness_fixed=True, + numpy=True ) with gr.Tab(label="Batch Folder") as self.batch_tab: @@ -337,28 +322,14 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.batch_input_gallery = gr.Gallery( columns=[4], rows=[2], object_fit="contain", height="auto", label="Images" ) - with gr.Row(): - self.merge_upload_button = gr.UploadButton( - "Upload Images", - file_types=["image"], - file_count="multiple", - ) - self.merge_clear_button = gr.Button("Clear Images") with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group: with gr.Column(): self.batch_mask_gallery = gr.Gallery( columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks" ) - with gr.Row(): - self.mask_merge_upload_button = gr.UploadButton( - "Upload Masks", - file_types=["image"], - file_count="multiple", - ) - self.mask_merge_clear_button = gr.Button("Clear Masks") if self.photopea: - self.photopea.attach_photopea_output(self.generated_image) + self.photopea.attach_photopea_output(self.generated_image.background) with gr.Accordion( label="Open New Canvas", visible=False @@ -397,23 +368,13 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.open_new_canvas_button = ToolButton( value=ControlNetUiGroup.open_symbol, elem_id=f"{elem_id_tabname}_{tabname}_controlnet_open_new_canvas_button", + elem_classes=["cnet-toolbutton"], tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.open_symbol], ) - self.webcam_enable = ToolButton( - value=ControlNetUiGroup.camera_symbol, - elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_enable", - 
tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.camera_symbol], - ) - self.webcam_mirror = ToolButton( - value=ControlNetUiGroup.reverse_symbol, - elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_mirror", - tooltip=ControlNetUiGroup.tooltips[ - ControlNetUiGroup.reverse_symbol - ], - ) self.send_dimen_button = ToolButton( value=ControlNetUiGroup.tossup_symbol, elem_id=f"{elem_id_tabname}_{tabname}_controlnet_send_dimen_button", + elem_classes=["cnet-toolbutton"], tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.tossup_symbol], ) @@ -481,7 +442,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: value=ControlNetUiGroup.trigger_symbol, visible=not self.is_img2img, elem_id=f"{elem_id_tabname}_{tabname}_controlnet_trigger_preprocessor", - elem_classes=["cnet-run-preprocessor"], + elem_classes=["cnet-run-preprocessor", "cnet-toolbutton"], tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.trigger_symbol], ) self.model = gr.Dropdown( @@ -493,6 +454,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.refresh_models = ToolButton( value=ControlNetUiGroup.refresh_symbol, elem_id=f"{elem_id_tabname}_{tabname}_controlnet_refresh_models", + elem_classes=["cnet-toolbutton"], tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.refresh_symbol], ) @@ -506,24 +468,22 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_weight_slider", elem_classes="controlnet_control_weight_slider", ) - self.guidance_start = gr.Slider( - label="Starting Control Step", - value=self.default_unit.guidance_start, - minimum=0.0, + self.timestep_range = RangeSlider( + label='Timestep Range', + minimum=0, maximum=1.0, - interactive=True, - elem_id=f"{elem_id_tabname}_{tabname}_controlnet_start_control_step_slider", - elem_classes="controlnet_start_control_step_slider", - ) - self.guidance_end = gr.Slider( - label="Ending Control Step", - value=self.default_unit.guidance_end, - minimum=0.0, - maximum=1.0, - interactive=True, - elem_id=f"{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider", - elem_classes="controlnet_ending_control_step_slider", + value=(self.default_unit.guidance_start, self.default_unit.guidance_end), + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_step_slider", + elem_classes="controlnet_control_step_slider", ) + self.guidance_start = gr.State(self.default_unit.guidance_start) + self.guidance_end = gr.State(self.default_unit.guidance_end) + + self.timestep_range.change( + lambda x: (x[0], x[1]), + inputs=[self.timestep_range], + outputs=[self.guidance_start, self.guidance_end] + ) # advanced options with gr.Column(visible=False) as self.advanced: @@ -581,18 +541,6 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: visible=False, ) - # self.loopback = gr.Checkbox( - # label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation", - # value=self.default_unit.loopback, - # elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox", - # elem_classes="controlnet_loopback_checkbox", - # visible=False, - # ) - - self.preset_panel = ControlNetPresetUI( - id_prefix=f"{elem_id_tabname}_{tabname}_" - ) - self.batch_image_dir_state = gr.State("") self.output_dir_state = gr.State("") unit_args = ( @@ -602,14 +550,16 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.batch_mask_dir, self.batch_input_gallery, self.batch_mask_gallery, - self.generated_image, - 
self.mask_image, + self.generated_image.background, + self.mask_image.background, + self.mask_image.foreground, self.hr_option, self.enabled, self.module, self.model, self.weight, - self.image, + self.image.background, + self.image.foreground, self.resize_mode, self.processor_res, self.threshold_a, @@ -665,41 +615,18 @@ def closesteight(num): else: return round(num + (8 - rem)) - if image: - interm = np.asarray(image.get("image")) - return closesteight(interm.shape[1]), closesteight(interm.shape[0]) + if image is not None: + return closesteight(image.shape[1]), closesteight(image.shape[0]) else: return gr.Slider.update(), gr.Slider.update() self.send_dimen_button.click( fn=send_dimensions, - inputs=[self.image], + inputs=[self.image.background], outputs=[self.width_slider, self.height_slider], show_progress=False, ) - def register_webcam_toggle(self): - def webcam_toggle(): - self.webcam_enabled = not self.webcam_enabled - return { - "value": None, - "source": "webcam" if self.webcam_enabled else "upload", - "__type__": "update", - } - - self.webcam_enable.click( - webcam_toggle, inputs=None, outputs=self.image, show_progress=False - ) - - def register_webcam_mirror_toggle(self): - def webcam_mirror_toggle(): - self.webcam_mirrored = not self.webcam_mirrored - return {"mirror_webcam": self.webcam_mirrored, "__type__": "update"} - - self.webcam_mirror.click( - webcam_mirror_toggle, inputs=None, outputs=self.image, show_progress=False - ) - def register_refresh_all_models(self): def refresh_all_models(): global_state.update_controlnet_filenames() @@ -799,16 +726,17 @@ def filter_selected(k: str): ) def register_run_annotator(self): - def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm): + def run_annotator(image, mask, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm): if image is None: return ( - gr.update(value=None, visible=True), + gr.update(visible=True), + None, gr.update(), *self.openpose_editor.update(""), ) - img = HWC3(image["image"]) - mask = HWC3(image["mask"]) + img = HWC3(image) + mask = HWC3(mask) if not (mask > 5).any(): mask = None @@ -862,8 +790,8 @@ def is_openpose(module: str): result = external_code.visualize_inpaint_mask(result) return ( - # Update to `generated_image` - gr.update(value=result, visible=True, interactive=False), + gr.update(visible=True), + result, # preprocessor_preview gr.update(value=True), # openpose editor @@ -873,7 +801,8 @@ def is_openpose(module: str): self.trigger_preprocessor.click( fn=run_annotator, inputs=[ - self.image, + self.image.background, + self.image.foreground, self.module, self.processor_res, self.threshold_a, @@ -884,7 +813,8 @@ def is_openpose(module: str): self.resize_mode, ], outputs=[ - self.generated_image, + self.generated_image.block, + self.generated_image.background, self.preprocessor_preview, *self.openpose_editor.outputs(), ], @@ -909,7 +839,7 @@ def shift_preview(is_on): fn=shift_preview, inputs=[self.preprocessor_preview], outputs=[ - self.generated_image, + self.generated_image.background, self.generated_image_group, self.use_preview_as_input, self.openpose_editor.download_link, @@ -920,27 +850,27 @@ def shift_preview(is_on): def register_create_canvas(self): self.open_new_canvas_button.click( - lambda: gr.Accordion.update(visible=True), + lambda: gr.update(visible=True), inputs=None, outputs=self.create_canvas, show_progress=False, ) self.canvas_cancel_button.click( - lambda: gr.Accordion.update(visible=False), + lambda: gr.update(visible=False), inputs=None, outputs=self.create_canvas, 
show_progress=False, ) def fn_canvas(h, w): - return np.zeros(shape=(h, w, 3), dtype=np.uint8), gr.Accordion.update( + return np.zeros(shape=(h, w, 3), dtype=np.uint8), gr.update( visible=False ) self.canvas_create_button.click( fn=fn_canvas, inputs=[self.canvas_height, self.canvas_width], - outputs=[self.image, self.create_canvas], + outputs=[self.image.background, self.create_canvas], show_progress=False, ) @@ -956,7 +886,7 @@ def fn_same_checked(x): fn_same_checked, inputs=self.upload_independent_img_in_img2img, outputs=[ - self.image, + self.image.background, self.batch_image_dir, self.preprocessor_preview, self.image_upload_panel, @@ -993,7 +923,7 @@ def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int): self.mask_upload.change( fn=on_checkbox_click, inputs=[self.mask_upload, self.height_slider, self.width_slider], - outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir, + outputs=[self.mask_image_group, self.mask_image.background, self.batch_mask_dir, self.batch_mask_gallery_group, self.batch_mask_gallery], show_progress=False, ) @@ -1073,106 +1003,27 @@ def clear_preview(x): event_subscriber( fn=clear_preview, inputs=self.use_preview_as_input, - outputs=[self.use_preview_as_input, self.generated_image], + outputs=[self.use_preview_as_input, self.generated_image.background], show_progress=False ) - def register_multi_images_upload(self): - """Register callbacks on merge tab multiple images upload.""" - self.merge_clear_button.click( - fn=lambda: [], - inputs=[], - outputs=[self.batch_input_gallery], - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - self.mask_merge_clear_button.click( - fn=lambda: [], - inputs=[], - outputs=[self.batch_mask_gallery], - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - - def upload_file(files, current_files): - return {file_d["name"] for file_d in current_files} | { - file.name for file in files - } - - self.merge_upload_button.upload( - upload_file, - inputs=[self.merge_upload_button, self.batch_input_gallery], - outputs=[self.batch_input_gallery], - queue=False, - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - self.mask_merge_upload_button.upload( - upload_file, - inputs=[self.mask_merge_upload_button, self.batch_mask_gallery], - outputs=[self.batch_mask_gallery], - queue=False, - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - return - def register_core_callbacks(self): """Register core callbacks that only involves gradio components defined within this ui group.""" - self.register_webcam_toggle() - self.register_webcam_mirror_toggle() self.register_refresh_all_models() self.register_build_sliders() self.register_shift_preview() self.register_create_canvas() self.register_clear_preview() - self.register_multi_images_upload() self.openpose_editor.register_callbacks( self.generated_image, self.use_preview_as_input, self.model, ) assert self.type_filter is not None - self.preset_panel.register_callbacks( - self, - self.type_filter, - *[ - getattr(self, key) - for key in external_code.ControlNetUnit.infotext_fields() - ], - ) if self.is_img2img: self.register_img2img_same_input() - def register_sd_model_changed(self): - def 
sd_version_changed(type_filter: str, current_model: str, setting_value: str, setting_name: str): - """When SD version changes, update model dropdown choices.""" - if setting_name != "sd_model_checkpoint": - return gr.update() - - filtered_model_list = global_state.get_filtered_controlnet_names(type_filter) - assert len(filtered_model_list) > 0 - default_model = filtered_model_list[1] if len(filtered_model_list) > 1 else filtered_model_list[0] - return gr.Dropdown.update( - choices=filtered_model_list, - value=current_model if current_model in filtered_model_list else default_model - ) - - script_callbacks.on_setting_updated_subscriber(dict( - fn=sd_version_changed, - inputs=[self.type_filter, self.model], - outputs=[self.model], - )) - def register_callbacks(self): """Register callbacks that involves A1111 context gradio components.""" # Prevent infinite recursion. @@ -1184,7 +1035,6 @@ def register_callbacks(self): self.register_run_annotator() self.register_sync_batch_dir() self.register_shift_upload_mask() - self.register_sd_model_changed() if self.is_img2img: self.register_shift_crop_input_image() else: diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py index 4146018a1..a8a7e784c 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py @@ -112,7 +112,7 @@ def render_pose(pose_url: str) -> Tuple[Dict, Dict]: self.render_button.click( fn=render_pose, inputs=[self.pose_input], - outputs=[generated_image, use_preview_as_input, *self.outputs()], + outputs=[generated_image.background, use_preview_as_input, *self.outputs()], ) def update_upload_link(model: str) -> Dict: diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py deleted file mode 100644 index 15a9f24ca..000000000 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py +++ /dev/null @@ -1,313 +0,0 @@ -import os -import gradio as gr - -from typing import Dict, List - -from modules import scripts -from lib_controlnet.infotext import parse_unit, serialize_unit -from lib_controlnet.controlnet_ui.tool_button import ToolButton -from lib_controlnet.logging import logger -from lib_controlnet.external_code import ControlNetUnit, UiControlNetUnit -from lib_controlnet.global_state import get_preprocessor -from modules_forge.supported_preprocessor import Preprocessor - -save_symbol = "\U0001f4be" # 💾 -delete_symbol = "\U0001f5d1\ufe0f" # 🗑️ -refresh_symbol = "\U0001f504" # 🔄 -reset_symbol = "\U000021A9" # ↩ - -NEW_PRESET = "New Preset" - - -def load_presets(preset_dir: str) -> Dict[str, str]: - if not os.path.exists(preset_dir): - os.makedirs(preset_dir) - return {} - - presets = {} - for filename in os.listdir(preset_dir): - if filename.endswith(".txt"): - with open(os.path.join(preset_dir, filename), "r") as f: - name = filename.replace(".txt", "") - if name == NEW_PRESET: - continue - presets[name] = f.read() - return presets - - -def infer_control_type(module: str) -> str: - preprocessor: Preprocessor = get_preprocessor(module) - assert preprocessor is not None - return preprocessor.tags[0] if preprocessor.tags else "All" - - -class ControlNetPresetUI(object): - preset_directory = os.path.join(scripts.basedir(), "presets") - 
presets = load_presets(preset_directory) - - def __init__(self, id_prefix: str): - with gr.Row(): - self.dropdown = gr.Dropdown( - label="Presets", - show_label=True, - elem_classes=["cnet-preset-dropdown"], - choices=ControlNetPresetUI.dropdown_choices(), - value=NEW_PRESET, - ) - self.reset_button = ToolButton( - value=reset_symbol, - elem_classes=["cnet-preset-reset"], - tooltip="Reset preset", - visible=False, - ) - self.save_button = ToolButton( - value=save_symbol, - elem_classes=["cnet-preset-save"], - tooltip="Save preset", - ) - self.delete_button = ToolButton( - value=delete_symbol, - elem_classes=["cnet-preset-delete"], - tooltip="Delete preset", - ) - self.refresh_button = ToolButton( - value=refresh_symbol, - elem_classes=["cnet-preset-refresh"], - tooltip="Refresh preset", - ) - - with gr.Box( - elem_classes=["popup-dialog", "cnet-preset-enter-name"], - elem_id=f"{id_prefix}_cnet_preset_enter_name", - ) as self.name_dialog: - with gr.Row(): - self.preset_name = gr.Textbox( - label="Preset name", - show_label=True, - lines=1, - elem_classes=["cnet-preset-name"], - ) - self.confirm_preset_name = ToolButton( - value=save_symbol, - elem_classes=["cnet-preset-confirm-name"], - tooltip="Save preset", - ) - - def register_callbacks( - self, - uigroup, - control_type: gr.Radio, - *ui_states, - ): - def init_with_ui_states(*ui_states) -> ControlNetUnit: - return ControlNetUnit(**{ - field: value - for field, value in zip(ControlNetUnit.infotext_fields(), ui_states) - }) - - def apply_preset(name: str, control_type: str, *ui_states): - if name == NEW_PRESET: - return ( - gr.update(visible=False), - *( - (gr.skip(),) - * (len(ControlNetUnit.infotext_fields()) + 1) - ), - ) - - assert name in ControlNetPresetUI.presets - - infotext = ControlNetPresetUI.presets[name] - preset_unit = parse_unit(infotext) - current_unit = init_with_ui_states(*ui_states) - preset_unit.image = None - current_unit.image = None - - # Do not compare module param that are not used in preset. - for module_param in ("processor_res", "threshold_a", "threshold_b"): - if getattr(preset_unit, module_param) == -1: - setattr(current_unit, module_param, -1) - - # No update necessary. 
- if vars(current_unit) == vars(preset_unit): - return ( - gr.update(visible=False), - *( - (gr.skip(),) - * (len(ControlNetUnit.infotext_fields()) + 1) - ), - ) - - unit = preset_unit - - try: - new_control_type = infer_control_type(unit.module) - except ValueError as e: - logger.error(e) - new_control_type = control_type - - if new_control_type != control_type: - uigroup.prevent_next_n_module_update += 1 - - if preset_unit.module != current_unit.module: - uigroup.prevent_next_n_slider_value_update += 1 - - if preset_unit.pixel_perfect != current_unit.pixel_perfect: - uigroup.prevent_next_n_slider_value_update += 1 - - return ( - gr.update(visible=True), - gr.update(value=new_control_type), - *[ - gr.update(value=value) if value is not None else gr.update() - for field in ControlNetUnit.infotext_fields() - for value in (getattr(unit, field),) - ], - ) - - for element, action in ( - (self.dropdown, "change"), - (self.reset_button, "click"), - ): - getattr(element, action)( - fn=apply_preset, - inputs=[self.dropdown, control_type, *ui_states], - outputs=[self.delete_button, control_type, *ui_states], - show_progress="hidden", - ).then( - fn=lambda: gr.update(visible=False), - inputs=None, - outputs=[self.reset_button], - ) - - def save_preset(name: str, *ui_states): - if name == NEW_PRESET: - return gr.update(visible=True), gr.update(), gr.update() - - ControlNetPresetUI.save_preset( - name, init_with_ui_states(*ui_states) - ) - return ( - gr.update(), # name dialog - gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name), - gr.update(visible=False), # Reset button - ) - - self.save_button.click( - fn=save_preset, - inputs=[self.dropdown, *ui_states], - outputs=[self.name_dialog, self.dropdown, self.reset_button], - show_progress="hidden", - ).then( - fn=None, - _js=f""" - (name) => {{ - if (name === "{NEW_PRESET}") - popup(gradioApp().getElementById('{self.name_dialog.elem_id}')); - }}""", - inputs=[self.dropdown], - ) - - def delete_preset(name: str): - ControlNetPresetUI.delete_preset(name) - return gr.Dropdown.update( - choices=ControlNetPresetUI.dropdown_choices(), - value=NEW_PRESET, - ), gr.update(visible=False) - - self.delete_button.click( - fn=delete_preset, - inputs=[self.dropdown], - outputs=[self.dropdown, self.reset_button], - show_progress="hidden", - ) - - self.name_dialog.visible = False - - def save_new_preset(new_name: str, *ui_states): - if new_name == NEW_PRESET: - logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'") - return gr.update(visible=False), gr.update() - - ControlNetPresetUI.save_preset( - new_name, init_with_ui_states(*ui_states) - ) - return gr.update(visible=False), gr.update( - choices=ControlNetPresetUI.dropdown_choices(), value=new_name - ) - - self.confirm_preset_name.click( - fn=save_new_preset, - inputs=[self.preset_name, *ui_states], - outputs=[self.name_dialog, self.dropdown], - show_progress="hidden", - ).then(fn=None, _js="closePopup") - - self.refresh_button.click( - fn=ControlNetPresetUI.refresh_preset, - inputs=None, - outputs=[self.dropdown], - show_progress="hidden", - ) - - def update_reset_button(preset_name: str, *ui_states): - if preset_name == NEW_PRESET: - return gr.update(visible=False) - - infotext = ControlNetPresetUI.presets[preset_name] - preset_unit = parse_unit(infotext) - current_unit = init_with_ui_states(*ui_states) - preset_unit.image = None - current_unit.image = None - - # Do not compare module param that are not used in preset. 
- for module_param in ("processor_res", "threshold_a", "threshold_b"): - if getattr(preset_unit, module_param) == -1: - setattr(current_unit, module_param, -1) - - return gr.update(visible=vars(current_unit) != vars(preset_unit)) - - for ui_state in ui_states: - if isinstance(ui_state, gr.Image): - continue - - for action in ("edit", "click", "change", "clear", "release"): - if action == "release" and not isinstance(ui_state, gr.Slider): - continue - - if hasattr(ui_state, action): - getattr(ui_state, action)( - fn=update_reset_button, - inputs=[self.dropdown, *ui_states], - outputs=[self.reset_button], - ) - - @staticmethod - def dropdown_choices() -> List[str]: - return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET] - - @staticmethod - def save_preset(name: str, unit: ControlNetUnit): - infotext = serialize_unit(unit) - with open( - os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w" - ) as f: - f.write(infotext) - - ControlNetPresetUI.presets[name] = infotext - - @staticmethod - def delete_preset(name: str): - if name not in ControlNetPresetUI.presets: - return - - del ControlNetPresetUI.presets[name] - - file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt") - if os.path.exists(file): - os.unlink(file) - - @staticmethod - def refresh_preset(): - ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory) - return gr.update(choices=ControlNetPresetUI.dropdown_choices()) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/tool_button.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/tool_button.py deleted file mode 100644 index 8a38df8f4..000000000 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/tool_button.py +++ /dev/null @@ -1,12 +0,0 @@ -import gradio as gr - -class ToolButton(gr.Button, gr.components.FormComponent): - """Small button with single emoji as text, fits inside gradio forms""" - - def __init__(self, **kwargs): - super().__init__(variant="tool", - elem_classes=kwargs.pop('elem_classes', []) + ["cnet-toolbutton"], - **kwargs) - - def get_block_name(self): - return "button" diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 4954478ac..e7e769633 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -155,76 +155,31 @@ class GradioImageMaskPair(TypedDict): @dataclass class ControlNetUnit: - """Represents an entire ControlNet processing unit. - - To add a new field to this class - ## If the new field can be specified on UI, you need to - - Add a new field of the same name in constructor of `ControlNetUiGroup` - - Initialize the new `ControlNetUiGroup` field in `ControlNetUiGroup.render` - as a Gradio `IOComponent`. - - Add the new `ControlNetUiGroup` field to `unit_args` in - `ControlNetUiGroup.render`. The order of parameters matters. - - ## If the new field needs to appear in infotext, you need to - - Add a new item in `ControlNetUnit.infotext_fields`. - API-only fields cannot appear in infotext. - """ - # Following fields should only be used in the UI. - # ====== Start of UI only fields ====== - # Specifies the input mode for the unit, defaulting to a simple mode. input_mode: InputMode = InputMode.SIMPLE - # Determines whether to use the preview image as input; defaults to False. 
use_preview_as_input: bool = False - # Directory path for batch processing of images. batch_image_dir: str = '' - # Directory path for batch processing of masks. batch_mask_dir: str = '' - # Optional list of gallery images for batch input; defaults to None. batch_input_gallery: Optional[List[str]] = None - # Optional list of gallery masks for batch processing; defaults to None. batch_mask_gallery: Optional[List[str]] = None - # Holds the preview image as a NumPy array; defaults to None. generated_image: Optional[np.ndarray] = None - # ====== End of UI only fields ====== - - # Following fields are used in both the API and the UI. - # Holds the mask image; defaults to None. mask_image: Optional[GradioImageMaskPair] = None - # Specifies how this unit should be applied in each pass of high-resolution fix. - # Ignored if high-resolution fix is not enabled. + mask_image_fg: Optional[GradioImageMaskPair] = None hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH - # Indicates whether the unit is enabled; defaults to True. enabled: bool = True - # Name of the module being used; defaults to "None". module: str = "None" - # Name of the model being used; defaults to "None". model: str = "None" - # Weight of the unit in the overall processing; defaults to 1.0. weight: float = 1.0 - # Optional image for input; defaults to None. image: Optional[GradioImageMaskPair] = None - # Specifies the mode of image resizing; defaults to inner fit. + image_fg: Optional[GradioImageMaskPair] = None resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT - # Resolution for processing by the unit; defaults to -1 (unspecified). processor_res: int = -1 - # Threshold A for processing; defaults to -1 (unspecified). threshold_a: float = -1 - # Threshold B for processing; defaults to -1 (unspecified). threshold_b: float = -1 - # Start value for guidance; defaults to 0.0. guidance_start: float = 0.0 - # End value for guidance; defaults to 1.0. guidance_end: float = 1.0 - # Enables pixel-perfect processing; defaults to False. pixel_perfect: bool = False - # Control mode for the unit; defaults to balanced. control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED - - # Following fields should only be used in the API. - # ====== Start of API only fields ====== - # Whether to save the detected map for this unit; defaults to True. 
save_detected_map: bool = True - # ====== End of API only fields ====== @staticmethod def infotext_fields(): diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 08aa5ff2a..9dc4a1d0a 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -157,7 +157,7 @@ def get_input_data(self, p, unit, preprocessor, h, w): if unit.input_mode == external_code.InputMode.MERGE: for idx, item in enumerate(unit.batch_input_gallery): - img_path = item['name'] + img_path = item[0] logger.info(f'Try to read image: {img_path}') img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy() mask = None @@ -197,30 +197,36 @@ def get_input_data(self, p, unit, preprocessor, h, w): using_a1111_data = False + unit_image = unit.image + unit_image_fg = unit.image_fg[:, :, 3] if unit.image_fg is not None else None + if unit.use_preview_as_input and unit.generated_image is not None: image = unit.generated_image elif unit.image is None: resize_mode = external_code.resize_mode_from_value(p.resize_mode) image = HWC3(np.asarray(a1111_i2i_image)) using_a1111_data = True - elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any(): - image = unit.image['mask'] + elif (unit_image < 5).all() and (unit_image_fg > 5).any(): + image = unit_image_fg else: - image = unit.image['image'] + image = unit_image if not isinstance(image, np.ndarray): raise ValueError("controlnet is enabled but no input image is given") image = HWC3(image) + unit_mask_image = unit.mask_image + unit_mask_image_fg = unit.mask_image_fg[:, :, 3] if unit.mask_image_fg is not None else None + if using_a1111_data: mask = HWC3(np.asarray(a1111_i2i_mask)) if a1111_i2i_mask is not None else None - elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any(): - mask = unit.mask_image['image'] - elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any(): - mask = unit.mask_image['mask'] - elif unit.image is not None and (unit.image['mask'] > 5).any(): - mask = unit.image['mask'] + elif unit_mask_image_fg is not None and (unit_mask_image_fg > 5).any(): + mask = unit_mask_image_fg + elif unit_mask_image is not None and (unit_mask_image > 5).any(): + mask = unit_mask_image + elif unit_image_fg is not None and (unit_image_fg > 5).any(): + mask = unit_image_fg else: mask = None diff --git a/extensions-builtin/sd_forge_controlnet/style.css b/extensions-builtin/sd_forge_controlnet/style.css index a27207411..35135c7c3 100644 --- a/extensions-builtin/sd_forge_controlnet/style.css +++ b/extensions-builtin/sd_forge_controlnet/style.css @@ -225,4 +225,27 @@ border-radius: var(--radius-sm); background: var(--background-fill-primary); color: var(--block-label-text-color); -} \ No newline at end of file +} + +.controlnet_control_type_filter_group label { + background: unset !important; + border: unset !important; + margin-left: -10px !important; +} + +.controlnet_control_type_filter_group > span { + display: none !important; +} + +.controlnet_control_type_filter_group > .wrap { + margin-top: -20px !important; +} + +.cnet-toolbutton { + background: unset !important; + border: unset !important; +} + +.range-slider { + margin-top: -8px; +} diff --git a/extensions-builtin/sd_forge_controlnet/tests/conftest.py b/extensions-builtin/sd_forge_controlnet/tests/conftest.py deleted file mode 100644 index c8792cd97..000000000 --- 
a/extensions-builtin/sd_forge_controlnet/tests/conftest.py +++ /dev/null @@ -1,7 +0,0 @@ -import os - - -def pytest_configure(config): - # We don't want to fail on Py.test command line arguments being - # parsed by webui: - os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/1girl.png b/extensions-builtin/sd_forge_controlnet/tests/images/1girl.png deleted file mode 100644 index d825e716b..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/1girl.png and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/mask.png b/extensions-builtin/sd_forge_controlnet/tests/images/mask.png deleted file mode 100644 index 166203af0..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/mask.png and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/mask_small.png b/extensions-builtin/sd_forge_controlnet/tests/images/mask_small.png deleted file mode 100644 index c48d77e47..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/mask_small.png and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/1.webp b/extensions-builtin/sd_forge_controlnet/tests/images/portrait/1.webp deleted file mode 100644 index 5b9eccf01..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/1.webp and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/2.jpg b/extensions-builtin/sd_forge_controlnet/tests/images/portrait/2.jpg deleted file mode 100644 index c16127c29..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/2.jpg and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/3.jpeg b/extensions-builtin/sd_forge_controlnet/tests/images/portrait/3.jpeg deleted file mode 100644 index 1936d9fc2..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/3.jpeg and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/4.jpg b/extensions-builtin/sd_forge_controlnet/tests/images/portrait/4.jpg deleted file mode 100644 index f0a0a8bf7..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/4.jpg and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/5.jpg b/extensions-builtin/sd_forge_controlnet/tests/images/portrait/5.jpg deleted file mode 100644 index 605a914aa..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/5.jpg and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/6.jpg b/extensions-builtin/sd_forge_controlnet/tests/images/portrait/6.jpg deleted file mode 100644 index c081789ad..000000000 Binary files a/extensions-builtin/sd_forge_controlnet/tests/images/portrait/6.jpg and /dev/null differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/__init__.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/detect_test.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/detect_test.py deleted file mode 100644 index 81f26b756..000000000 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/detect_test.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import requests -from typing import List - 
-from .template import ( - APITestTemplate, - realistic_girl_face_img, - save_base64, - get_dest_dir, - disable_in_cq, -) - - -def get_modules() -> List[str]: - return requests.get(APITestTemplate.BASE_URL + "controlnet/module_list").json()[ - "module_list" - ] - - -def detect_template(payload, output_name: str): - url = APITestTemplate.BASE_URL + "controlnet/detect" - resp = requests.post(url, json=payload) - assert resp.status_code == 200 - resp_json = resp.json() - assert "images" in resp_json - assert len(resp_json["images"]) == len(payload["controlnet_input_images"]) - if not APITestTemplate.is_cq_run: - for i, img in enumerate(resp_json["images"]): - if img == "Detect result is not image": - continue - dest = get_dest_dir() / f"{output_name}_{i}.png" - save_base64(img, dest) - return resp_json - - -@disable_in_cq -@pytest.mark.parametrize("module", get_modules()) -def test_detect_all_modules(module: str): - payload = dict( - controlnet_input_images=[realistic_girl_face_img], - controlnet_module=module, - ) - detect_template(payload, f"detect_{module}") - - -def test_detect_simple(): - detect_template( - dict( - controlnet_input_images=[realistic_girl_face_img], - controlnet_module="canny", # Canny does not require model download. - ), - "simple_detect", - ) - - -def test_detect_multiple_inputs(): - detect_template( - dict( - controlnet_input_images=[realistic_girl_face_img, realistic_girl_face_img], - controlnet_module="canny", # Canny does not require model download. - ), - "multiple_inputs_detect", - ) diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py deleted file mode 100644 index 433819d15..000000000 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py +++ /dev/null @@ -1,171 +0,0 @@ -import pytest - -from .template import ( - APITestTemplate, - girl_img, - mask_img, - disable_in_cq, - get_model, -) - - -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_no_unit(gen_type): - assert APITestTemplate( - f"test_no_unit{gen_type}", - gen_type, - payload_overrides={}, - unit_overrides=[], - input_image=girl_img, - ).exec() - - -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_multiple_iter(gen_type): - assert APITestTemplate( - f"test_multiple_iter{gen_type}", - gen_type, - payload_overrides={"n_iter": 2}, - unit_overrides={}, - input_image=girl_img, - ).exec() - - -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_batch_size(gen_type): - assert APITestTemplate( - f"test_batch_size{gen_type}", - gen_type, - payload_overrides={"batch_size": 2}, - unit_overrides={}, - input_image=girl_img, - ).exec() - - -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_2_units(gen_type): - assert APITestTemplate( - f"test_2_units{gen_type}", - gen_type, - payload_overrides={}, - unit_overrides=[{}, {}], - input_image=girl_img, - ).exec() - - -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_preprocessor(gen_type): - assert APITestTemplate( - f"test_preprocessor{gen_type}", - gen_type, - payload_overrides={}, - unit_overrides={"module": "canny"}, - input_image=girl_img, - ).exec() - - -@pytest.mark.parametrize("param_name", ("processor_res", "threshold_a", "threshold_b")) -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_invalid_param(gen_type, param_name): - assert APITestTemplate( - f"test_invalid_param{(gen_type, param_name)}", - gen_type, - 
payload_overrides={}, - unit_overrides={param_name: -1}, - input_image=girl_img, - ).exec() - - -@pytest.mark.parametrize("save_map", [True, False]) -@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) -def test_save_map(gen_type, save_map): - assert APITestTemplate( - f"test_save_map{(gen_type, save_map)}", - gen_type, - payload_overrides={}, - unit_overrides={"save_detected_map": save_map}, - input_image=girl_img, - ).exec(expected_output_num=2 if save_map else 1) - - -@disable_in_cq -def test_masked_controlnet_txt2img(): - assert APITestTemplate( - f"test_masked_controlnet_txt2img", - "txt2img", - payload_overrides={}, - unit_overrides={ - "image": girl_img, - "mask_image": mask_img, - }, - ).exec() - - -@disable_in_cq -def test_masked_controlnet_img2img(): - assert APITestTemplate( - f"test_masked_controlnet_img2img", - "img2img", - payload_overrides={ - "init_images": [girl_img], - }, - # Note: Currently you must give ControlNet unit input image to specify - # mask. - # TODO: Fix this for img2img. - unit_overrides={ - "image": girl_img, - "mask_image": mask_img, - }, - ).exec() - - -@disable_in_cq -def test_txt2img_inpaint(): - assert APITestTemplate( - "txt2img_inpaint", - "txt2img", - payload_overrides={}, - unit_overrides={ - "image": girl_img, - "mask_image": mask_img, - "model": get_model("v11p_sd15_inpaint"), - "module": "inpaint_only", - }, - ).exec() - - -@disable_in_cq -def test_img2img_inpaint(): - assert APITestTemplate( - "img2img_inpaint", - "img2img", - payload_overrides={ - "init_images": [girl_img], - "mask": mask_img, - }, - unit_overrides={ - "model": get_model("v11p_sd15_inpaint"), - "module": "inpaint_only", - }, - ).exec() - - -# Currently failing. -# TODO Fix lama outpaint. -@disable_in_cq -def test_lama_outpaint(): - assert APITestTemplate( - "txt2img_lama_outpaint", - "txt2img", - payload_overrides={ - "width": 768, - "height": 768, - }, - # Outpaint should not need a mask. 
- unit_overrides={ - "image": girl_img, - "model": get_model("v11p_sd15_inpaint"), - "module": "inpaint_only+lama", - "resize_mode": "Resize and Fill", # OUTER_FIT - }, - ).exec() diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py deleted file mode 100644 index 5129e541f..000000000 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py +++ /dev/null @@ -1,347 +0,0 @@ -import io -import os -import cv2 -import base64 -import functools -from typing import Dict, Any, List, Union, Literal, Optional -from pathlib import Path -import datetime -from enum import Enum -import numpy as np -import pytest - -import requests -from PIL import Image - - -def disable_in_cq(func): - """Skips the decorated test func in CQ run.""" - @functools.wraps(func) - def wrapped_func(*args, **kwargs): - if APITestTemplate.is_cq_run: - pytest.skip() - return func(*args, **kwargs) - return wrapped_func - - -PayloadOverrideType = Dict[str, Any] - -timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") -test_result_dir = Path(__file__).parent / "results" / f"test_result_{timestamp}" -test_expectation_dir = Path(__file__).parent / "expectations" -os.makedirs(test_expectation_dir, exist_ok=True) -resource_dir = Path(__file__).parents[1] / "images" - - -def get_dest_dir(): - if APITestTemplate.is_set_expectation_run: - return test_expectation_dir - else: - return test_result_dir - - -def save_base64(base64img: str, dest: Path): - Image.open(io.BytesIO(base64.b64decode(base64img.split(",", 1)[0]))).save(dest) - - -def read_image(img_path: Path) -> str: - img = cv2.imread(str(img_path)) - _, bytes = cv2.imencode(".png", img) - encoded_image = base64.b64encode(bytes).decode("utf-8") - return encoded_image - - -def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]: - """Try read all images in given img_dir.""" - img_dir = str(img_dir) - images = [] - for filename in os.listdir(img_dir): - if filename.endswith(suffixes): - img_path = os.path.join(img_dir, filename) - try: - images.append(read_image(img_path)) - except IOError: - print(f"Error opening {img_path}") - return images - - -girl_img = read_image(resource_dir / "1girl.png") -mask_img = read_image(resource_dir / "mask.png") -mask_small_img = read_image(resource_dir / "mask_small.png") -portrait_imgs = read_image_dir(resource_dir / "portrait") -realistic_girl_face_img = portrait_imgs[0] - - -general_negative_prompt = """ -(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, -((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, -backlight,(ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), -(tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, (bad anatomy:1.21), -(bad proportions:1.331), extra limbs, (missing arms:1.331), (extra legs:1.331), -(fused fingers:1.61051), (too many fingers:1.61051), (unclear eyes:1.331), bad hands, -missing fingers, extra digit, bad body, easynegative, nsfw""" - -class StableDiffusionVersion(Enum): - """The version family of stable diffusion model.""" - - UNKNOWN = 0 - SD1x = 1 - SD2x = 2 - SDXL = 3 - - -sd_version = StableDiffusionVersion( - int(os.environ.get("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value)) -) - -is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not None - - -class APITestTemplate: - is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == 
"True" - is_cq_run = os.environ.get("FORGE_CQ_TEST", "False") == "True" - BASE_URL = "http://localhost:7860/" - - def __init__( - self, - name: str, - gen_type: Union[Literal["img2img"], Literal["txt2img"]], - payload_overrides: PayloadOverrideType, - unit_overrides: Union[PayloadOverrideType, List[PayloadOverrideType]], - input_image: Optional[str] = None, - ): - self.name = name - self.url = APITestTemplate.BASE_URL + "sdapi/v1/" + gen_type - self.payload = { - **(txt2img_payload if gen_type == "txt2img" else img2img_payload), - **payload_overrides, - } - if gen_type == "img2img" and input_image is not None: - self.payload["init_images"] = [input_image] - - # CQ runs on CPU. Reduce steps to increase test speed. - if "steps" not in payload_overrides and APITestTemplate.is_cq_run: - self.payload["steps"] = 3 - - unit_overrides = ( - unit_overrides - if isinstance(unit_overrides, (list, tuple)) - else [unit_overrides] - ) - self.payload["alwayson_scripts"]["ControlNet"]["args"] = [ - { - **default_unit, - **unit_override, - **({"image": input_image} if gen_type == "txt2img" and input_image is not None else {}), - } - for unit_override in unit_overrides - ] - self.active_unit_count = len(unit_overrides) - - def exec(self, *args, **kwargs) -> bool: - if APITestTemplate.is_cq_run: - return self.exec_cq(*args, **kwargs) - else: - return self.exec_local(*args, **kwargs) - - def exec_cq(self, expected_output_num: Optional[int] = None, *args, **kwargs) -> bool: - """Execute test in CQ environment.""" - res = requests.post(url=self.url, json=self.payload) - if res.status_code != 200: - print(f"Unexpected status code {res.status_code}") - return False - - response = res.json() - if "images" not in response: - print(response.keys()) - return False - - if expected_output_num is None: - expected_output_num = self.payload["n_iter"] * self.payload["batch_size"] + self.active_unit_count - - if len(response["images"]) != expected_output_num: - print(f"{len(response['images'])} != {expected_output_num}") - return False - - return True - - def exec_local(self, result_only: bool = True, *args, **kwargs) -> bool: - """Execute test in local environment.""" - if not APITestTemplate.is_set_expectation_run: - os.makedirs(test_result_dir, exist_ok=True) - - failed = False - - response = requests.post(url=self.url, json=self.payload).json() - if "images" not in response: - print(response.keys()) - return False - - dest_dir = get_dest_dir() - results = response["images"][:1] if result_only else response["images"] - for i, base64image in enumerate(results): - img_file_name = f"{self.name}_{i}.png" - save_base64(base64image, dest_dir / img_file_name) - - if not APITestTemplate.is_set_expectation_run: - try: - img1 = cv2.imread(os.path.join(test_expectation_dir, img_file_name)) - img2 = cv2.imread(os.path.join(test_result_dir, img_file_name)) - except Exception as e: - print(f"Get exception reading imgs: {e}") - failed = True - continue - - if img1 is None: - print(f"Warn: No expectation file found {img_file_name}.") - continue - - if not expect_same_image( - img1, - img2, - diff_img_path=str(test_result_dir - / img_file_name.replace(".png", "_diff.png")), - ): - failed = True - return not failed - - -def expect_same_image(img1, img2, diff_img_path: str) -> bool: - # Calculate the difference between the two images - diff = cv2.absdiff(img1, img2) - - # Set a threshold to highlight the different pixels - threshold = 30 - diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8) - - # Assert that the two 
images are similar within a tolerance - similar = np.allclose(img1, img2, rtol=0.5, atol=1) - if not similar: - # Save the diff_highlighted image to inspect the differences - cv2.imwrite(diff_img_path, diff_highlighted) - - matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1) - similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95 - return similar_in_general - - -def get_model(model_name: str) -> str: - """ Find an available model with specified model name.""" - if model_name.lower() == "none": - return "None" - - r = requests.get(APITestTemplate.BASE_URL + "controlnet/model_list") - result = r.json() - if "model_list" not in result: - raise ValueError("No model available") - - candidates = [ - model - for model in result["model_list"] - if model_name.lower() in model.lower() - ] - - if not candidates: - raise ValueError("No suitable model available") - - return candidates[0] - - -default_unit = { - "control_mode": 0, - "enabled": True, - "guidance_end": 1, - "guidance_start": 0, - "pixel_perfect": True, - "processor_res": 512, - "resize_mode": 1, - "threshold_a": 64, - "threshold_b": 64, - "weight": 1, - "module": "canny", - "model": get_model("sd15_canny"), -} - -img2img_payload = { - "batch_size": 1, - "cfg_scale": 7, - "height": 768, - "width": 512, - "n_iter": 1, - "steps": 10, - "sampler_name": "Euler a", - "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,", - "negative_prompt": "", - "seed": 42, - "seed_enable_extras": False, - "seed_resize_from_h": 0, - "seed_resize_from_w": 0, - "subseed": -1, - "subseed_strength": 0, - "override_settings": {}, - "override_settings_restore_afterwards": False, - "do_not_save_grid": False, - "do_not_save_samples": False, - "s_churn": 0, - "s_min_uncond": 0, - "s_noise": 1, - "s_tmax": None, - "s_tmin": 0, - "script_args": [], - "script_name": None, - "styles": [], - "alwayson_scripts": {"ControlNet": {"args": [default_unit]}}, - "denoising_strength": 0.75, - "initial_noise_multiplier": 1, - "inpaint_full_res": 0, - "inpaint_full_res_padding": 32, - "inpainting_fill": 1, - "inpainting_mask_invert": 0, - "mask_blur_x": 4, - "mask_blur_y": 4, - "mask_blur": 4, - "resize_mode": 0, -} - -txt2img_payload = { - "alwayson_scripts": {"ControlNet": {"args": [default_unit]}}, - "batch_size": 1, - "cfg_scale": 7, - "comments": {}, - "disable_extra_networks": False, - "do_not_save_grid": False, - "do_not_save_samples": False, - "enable_hr": False, - "height": 768, - "hr_negative_prompt": "", - "hr_prompt": "", - "hr_resize_x": 0, - "hr_resize_y": 0, - "hr_scale": 2, - "hr_second_pass_steps": 0, - "hr_upscaler": "Latent", - "n_iter": 1, - "negative_prompt": "", - "override_settings": {}, - "override_settings_restore_afterwards": True, - "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,", - "restore_faces": False, - "s_churn": 0.0, - "s_min_uncond": 0, - "s_noise": 1.0, - "s_tmax": None, - "s_tmin": 0.0, - "sampler_name": "Euler a", - "script_args": [], - "script_name": None, - "seed": 42, - "seed_enable_extras": True, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "steps": 10, - "styles": [], - "subseed": -1, - "subseed_strength": 0, - "tiling": False, - "width": 512, -} diff --git a/extensions-builtin/sd_forge_controlnet_example/preload.py b/extensions-builtin/sd_forge_controlnet_example/preload.py deleted file mode 100644 index ddc29489a..000000000 --- a/extensions-builtin/sd_forge_controlnet_example/preload.py +++ /dev/null @@ -1,6 +0,0 @@ -def preload(parser): - parser.add_argument( - 
"--show-controlnet-example", - action="store_true", - help="Show development example extension for ControlNet.", - ) diff --git a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py deleted file mode 100644 index 9c10cb23b..000000000 --- a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py +++ /dev/null @@ -1,160 +0,0 @@ -# Use --show-controlnet-example to see this extension. - -import cv2 -import gradio as gr -import torch - -from modules import scripts -from modules.shared_cmd_options import cmd_opts -from modules_forge.shared import supported_preprocessors -from modules.modelloader import load_file_from_url -from ldm_patched.modules.controlnet import load_controlnet -from modules_forge.controlnet import apply_controlnet_advanced -from modules_forge.forge_util import numpy_to_pytorch -from modules_forge.shared import controlnet_dir - - -class ControlNetExampleForge(scripts.Script): - model = None - - def title(self): - return "ControlNet Example for Developers" - - def show(self, is_img2img): - # make this extension visible in both txt2img and img2img tab. - return scripts.AlwaysVisible - - def ui(self, *args, **kwargs): - with gr.Accordion(open=False, label=self.title()): - gr.HTML('This is an example controlnet extension for developers.') - gr.HTML('You see this extension because you used --show-controlnet-example') - input_image = gr.Image(source='upload', type='numpy') - funny_slider = gr.Slider(label='This slider does nothing. It just shows you how to transfer parameters.', - minimum=0.0, maximum=1.0, value=0.5) - - return input_image, funny_slider - - def process(self, p, *script_args, **kwargs): - input_image, funny_slider = script_args - - # This slider does nothing. It just shows you how to transfer parameters. - del funny_slider - - if input_image is None: - return - - # controlnet_canny_path = load_file_from_url( - # url='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_canny_256lora.safetensors', - # model_dir=model_dir, - # file_name='sai_xl_canny_256lora.safetensors' - # ) - controlnet_canny_path = load_file_from_url( - url='https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/control_v11p_sd15_canny_fp16.safetensors', - model_dir=controlnet_dir, - file_name='control_v11p_sd15_canny_fp16.safetensors' - ) - print('The model [control_v11p_sd15_canny_fp16.safetensors] download finished.') - - self.model = load_controlnet(controlnet_canny_path) - print('Controlnet loaded.') - - return - - def process_before_every_sampling(self, p, *script_args, **kwargs): - # This will be called before every sampling. - # If you use highres fix, this will be called twice. - - input_image, funny_slider = script_args - - if input_image is None or self.model is None: - return - - B, C, H, W = kwargs['noise'].shape # latent_shape - height = H * 8 - width = W * 8 - batch_size = p.batch_size - - preprocessor = supported_preprocessors['canny'] - - # detect control at certain resolution - control_image = preprocessor( - input_image, resolution=512, slider_1=100, slider_2=200, slider_3=None) - - # here we just use nearest neighbour to align input shape. - # You may want crop and resize, or crop and fill, or others. - control_image = cv2.resize( - control_image, (width, height), interpolation=cv2.INTER_NEAREST) - - # Output preprocessor result. Now called every sampling. Cache in your own way. 
- p.extra_result_images.append(control_image) - - print('Preprocessor Canny finished.') - - control_image_bchw = numpy_to_pytorch(control_image).movedim(-1, 1) - - unet = p.sd_model.forge_objects.unet - - # Unet has input, middle, output blocks, and we can give different weights - # to each layers in all blocks. - # Below is an example for stronger control in middle block. - # This is helpful for some high-res fix passes. (p.is_hr_pass) - positive_advanced_weighting = { - 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], - 'middle': [1.0], - 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] - } - negative_advanced_weighting = { - 'input': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25], - 'middle': [1.05], - 'output': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25] - } - - # The advanced_frame_weighting is a weight applied to each image in a batch. - # The length of this list must be same with batch size - # For example, if batch size is 5, the below list is [0.2, 0.4, 0.6, 0.8, 1.0] - # If you view the 5 images as 5 frames in a video, this will lead to - # progressively stronger control over time. - advanced_frame_weighting = [float(i + 1) / float(batch_size) for i in range(batch_size)] - - # The advanced_sigma_weighting allows you to dynamically compute control - # weights given diffusion timestep (sigma). - # For example below code can softly make beginning steps stronger than ending steps. - sigma_max = unet.model.model_sampling.sigma_max - sigma_min = unet.model.model_sampling.sigma_min - advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) - - # You can even input a tensor to mask all control injections - # The mask will be automatically resized during inference in UNet. - # The size should be B 1 H W and the H and W are not important - # because they will be resized automatically - advanced_mask_weighting = torch.ones(size=(1, 1, 512, 512)) - - # But in this simple example we do not use them - positive_advanced_weighting = None - negative_advanced_weighting = None - advanced_frame_weighting = None - advanced_sigma_weighting = None - advanced_mask_weighting = None - - unet = apply_controlnet_advanced(unet=unet, controlnet=self.model, image_bchw=control_image_bchw, - strength=0.6, start_percent=0.0, end_percent=0.8, - positive_advanced_weighting=positive_advanced_weighting, - negative_advanced_weighting=negative_advanced_weighting, - advanced_frame_weighting=advanced_frame_weighting, - advanced_sigma_weighting=advanced_sigma_weighting, - advanced_mask_weighting=advanced_mask_weighting) - - p.sd_model.forge_objects.unet = unet - - # Below codes will add some logs to the texts below the image outputs on UI. - # The extra_generation_params does not influence results. - p.extra_generation_params.update(dict( - controlnet_info='You should see these texts below output images!', - )) - - return - - -# Use --show-controlnet-example to see this extension. 
-if not cmd_opts.show_controlnet_example: - del ControlNetExampleForge diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py index d90243442..d5716d37c 100644 --- a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py +++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py @@ -3,6 +3,7 @@ import math from modules.ui_components import InputAccordion import modules.scripts as scripts +from modules.torch_utils import float64 class SoftInpaintingSettings: @@ -57,10 +58,14 @@ def latent_blend(settings, a, b, t): # NOTE: We use inplace operations wherever possible. - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) + if len(t.shape) == 3: + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + else: + t2 = t + t3 = t[:, 0][:, None] one_minus_t2 = 1 - t2 one_minus_t3 = 1 - t3 @@ -75,13 +80,11 @@ def latent_blend(settings, a, b, t): # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(float64(image_interp)).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - settings.inpaint_detail_preservation) * one_minus_t3 - b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - settings.inpaint_detail_preservation) * t3 + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(float64(a)).pow_(settings.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(float64(b)).pow_(settings.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) del a_magnitude, b_magnitude, t3, one_minus_t3 @@ -104,7 +107,7 @@ def latent_blend(settings, a, b, t): def get_modified_nmask(settings, nmask, sigma): """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed + Converts a negative mask representing the transparency of the original latent vectors being overlaid to a mask that is scaled according to the denoising strength for this step. Where: @@ -135,7 +138,10 @@ def apply_adaptive_masks( from PIL import Image, ImageOps, ImageFilter # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. 
- latent_mask = nmask[0].float() + if len(nmask.shape) == 3: + latent_mask = nmask[0].float() + else: + latent_mask = nmask[:, 0].float() # convert the original mask into a form we use to scale distances for thresholding mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) mask_scalar = (0.5 * (1 - settings.composite_mask_influence) @@ -157,7 +163,14 @@ def apply_adaptive_masks( percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% - half_weighted_distance = settings.composite_difference_threshold * mask_scalar + if len(mask_scalar.shape) == 3: + if mask_scalar.shape[0] > i: + half_weighted_distance = settings.composite_difference_threshold * mask_scalar[i] + else: + half_weighted_distance = settings.composite_difference_threshold * mask_scalar[0] + else: + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) @@ -472,7 +485,7 @@ def gaussian_kernel_func(coordinate): class Script(scripts.Script): def __init__(self): - self.section = "inpaint" + # self.section = "inpaint" self.masks_for_overlay = None self.overlay_images = None diff --git a/html/extra-networks-copy-path-button.html b/html/extra-networks-copy-path-button.html index 8083bb033..50304b42d 100644 --- a/html/extra-networks-copy-path-button.html +++ b/html/extra-networks-copy-path-button.html @@ -1,5 +1,5 @@
\ No newline at end of file diff --git a/html/extra-networks-edit-item-button.html b/html/extra-networks-edit-item-button.html index 0fe43082a..fd728600f 100644 --- a/html/extra-networks-edit-item-button.html +++ b/html/extra-networks-edit-item-button.html @@ -1,4 +1,4 @@
+ onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}')">
\ No newline at end of file diff --git a/html/extra-networks-metadata-button.html b/html/extra-networks-metadata-button.html index 285b5b3b6..4ef013bc0 100644 --- a/html/extra-networks-metadata-button.html +++ b/html/extra-networks-metadata-button.html @@ -1,4 +1,4 @@
+ onclick="extraNetworksRequestMetadata(event, '{extra_networks_tabname}')">
\ No newline at end of file diff --git a/html/extra-networks-pane-dirs.html b/html/extra-networks-pane-dirs.html new file mode 100644 index 000000000..d7c9661a0 --- /dev/null +++ b/html/extra-networks-pane-dirs.html @@ -0,0 +1,8 @@ +
+
+ {dirs_html} +
+
+ {items_html} +
+
diff --git a/html/extra-networks-pane-tree.html b/html/extra-networks-pane-tree.html new file mode 100644 index 000000000..e4d92a359 --- /dev/null +++ b/html/extra-networks-pane-tree.html @@ -0,0 +1,8 @@ +
+
+ {tree_html} +
+
+ {items_html} +
+
\ No newline at end of file diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 0c763f710..9a67baea9 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,23 +1,53 @@ -
+
-
-
- {tree_html} -
-
- {items_html} +
-
\ No newline at end of file + {pane_content} +
diff --git a/html/footer.html b/html/footer.html index 69b2372c7..8fe2bf8da 100644 --- a/html/footer.html +++ b/html/footer.html @@ -1,7 +1,7 @@
API  •  - Github + Github  •  Gradio  •  diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 2cf2d571f..90aa25c99 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -1,10 +1,8 @@ - -let currentWidth = null; -let currentHeight = null; -let arFrameTimeout = setTimeout(function() {}, 0); +let currentWidth; +let currentHeight; +let arFrameTimeout; function dimensionChange(e, is_width, is_height) { - if (is_width) { currentWidth = e.target.value * 1.0; } @@ -22,18 +20,18 @@ function dimensionChange(e, is_width, is_height) { var tabIndex = get_tab_index('mode_img2img'); if (tabIndex == 0) { // img2img - targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); + targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] canvas'); } else if (tabIndex == 1) { //Sketch - targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); + targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] canvas'); } else if (tabIndex == 2) { // Inpaint - targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); + targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] canvas'); } else if (tabIndex == 3) { // Inpaint sketch - targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); + targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] canvas'); + } else if (tabIndex == 4) { // Inpaint upload + targetElement = gradioApp().querySelector('#img_inpaint_base div[data-testid=image] img'); } - if (targetElement) { - var arPreviewRect = gradioApp().querySelector('#imageARPreview'); if (!arPreviewRect) { arPreviewRect = document.createElement('div'); @@ -41,26 +39,23 @@ function dimensionChange(e, is_width, is_height) { gradioApp().appendChild(arPreviewRect); } - - var viewportOffset = targetElement.getBoundingClientRect(); + var viewportscale = Math.min(targetElement.clientWidth / targetElement.width, targetElement.clientHeight / targetElement.height); - var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight); + var scaledx = targetElement.width * viewportscale; + var scaledy = targetElement.height * viewportscale; - var scaledx = targetElement.naturalWidth * viewportscale; - var scaledy = targetElement.naturalHeight * viewportscale; - - var cleintRectTop = (viewportOffset.top + window.scrollY); - var cleintRectLeft = (viewportOffset.left + window.scrollX); - var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2); - var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2); + var clientRectTop = (viewportOffset.top + window.scrollY); + var clientRectLeft = (viewportOffset.left + window.scrollX); + var clientRectCentreY = clientRectTop + (targetElement.clientHeight / 2); + var clientRectCentreX = clientRectLeft + (targetElement.clientWidth / 2); var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight); var arscaledx = currentWidth * arscale; var arscaledy = currentHeight * arscale; - var arRectTop = cleintRectCentreY - (arscaledy / 2); - var arRectLeft = cleintRectCentreX - (arscaledx / 2); + var arRectTop = clientRectCentreY - (arscaledy / 2); + var arRectLeft = clientRectCentreX - (arscaledx / 2); var arRectWidth = arscaledx; var arRectHeight = arscaledy; @@ -75,21 +70,18 @@ function dimensionChange(e, 
is_width, is_height) { }, 2000); arPreviewRect.style.display = 'block'; - } - } - onAfterUiUpdate(function() { var arPreviewRect = gradioApp().querySelector('#imageARPreview'); if (arPreviewRect) { arPreviewRect.style.display = 'none'; } + var tabImg2img = gradioApp().querySelector("#tab_img2img"); if (tabImg2img) { - var inImg2img = tabImg2img.style.display == "block"; - if (inImg2img) { + if (tabImg2img.style.display == "block") { let inputs = gradioApp().querySelectorAll('input'); inputs.forEach(function(e) { var is_width = e.parentElement.id == "img2img_width"; diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index ccae242f2..e01fd67e8 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -8,9 +8,6 @@ var contextMenuInit = function() { }; function showContextMenu(event, element, menuEntries) { - let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; - let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; - let oldMenu = gradioApp().querySelector('#context-menu'); if (oldMenu) { oldMenu.remove(); @@ -23,10 +20,8 @@ var contextMenuInit = function() { contextMenu.style.background = baseStyle.background; contextMenu.style.color = baseStyle.color; contextMenu.style.fontFamily = baseStyle.fontFamily; - contextMenu.style.top = posy + 'px'; - contextMenu.style.left = posx + 'px'; - - + contextMenu.style.top = event.pageY + 'px'; + contextMenu.style.left = event.pageX + 'px'; const contextMenuList = document.createElement('ul'); contextMenuList.className = 'context-menu-items'; @@ -43,21 +38,6 @@ var contextMenuInit = function() { }); gradioApp().appendChild(contextMenu); - - let menuWidth = contextMenu.offsetWidth + 4; - let menuHeight = contextMenu.offsetHeight + 4; - - let windowWidth = window.innerWidth; - let windowHeight = window.innerHeight; - - if ((windowWidth - posx) < menuWidth) { - contextMenu.style.left = windowWidth - menuWidth + "px"; - } - - if ((windowHeight - posy) < menuHeight) { - contextMenu.style.top = windowHeight - menuHeight + "px"; - } - } function appendContextMenuOption(targetElementSelector, entryName, entryFunction) { @@ -107,16 +87,23 @@ var contextMenuInit = function() { oldMenu.remove(); } }); - gradioApp().addEventListener("contextmenu", function(e) { - let oldMenu = gradioApp().querySelector('#context-menu'); - if (oldMenu) { - oldMenu.remove(); - } - menuSpecs.forEach(function(v, k) { - if (e.composedPath()[0].matches(k)) { - showContextMenu(e, e.composedPath()[0], v); - e.preventDefault(); + ['contextmenu', 'touchstart'].forEach((eventType) => { + gradioApp().addEventListener(eventType, function(e) { + let ev = e; + if (eventType.startsWith('touch')) { + if (e.touches.length !== 2) return; + ev = e.touches[0]; + } + let oldMenu = gradioApp().querySelector('#context-menu'); + if (oldMenu) { + oldMenu.remove(); } + menuSpecs.forEach(function(v, k) { + if (e.composedPath()[0].matches(k)) { + showContextMenu(ev, e.composedPath()[0], v); + e.preventDefault(); + } + }); }); }); eventListenerApplied = true; diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index d680daf52..882562d73 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -56,6 +56,15 @@ function eventHasFiles(e) { return false; } +function isURL(url) { + try { + const _ = new URL(url); + return true; + } catch { + return false; + } +} + function dragDropTargetIsPrompt(target) { if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true; if 
(target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true; @@ -74,22 +83,39 @@ window.document.addEventListener('dragover', e => { e.dataTransfer.dropEffect = 'copy'; }); -window.document.addEventListener('drop', e => { +window.document.addEventListener('drop', async e => { const target = e.composedPath()[0]; - if (!eventHasFiles(e)) return; + const url = e.dataTransfer.getData('text/uri-list') || e.dataTransfer.getData('text/plain'); + if (!eventHasFiles(e) && !isURL(url)) return; if (dragDropTargetIsPrompt(target)) { e.stopPropagation(); e.preventDefault(); - let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; + const isImg2img = get_tab_index('tabs') == 1; + let prompt_image_target = isImg2img ? "img2img_prompt_image" : "txt2img_prompt_image"; - const imgParent = gradioApp().getElementById(prompt_target); + const imgParent = gradioApp().getElementById(prompt_image_target); const files = e.dataTransfer.files; const fileInput = imgParent.querySelector('input[type="file"]'); - if (fileInput) { + if (eventHasFiles(e) && fileInput) { fileInput.files = files; fileInput.dispatchEvent(new Event('change')); + } else if (url) { + try { + const request = await fetch(url); + if (!request.ok) { + console.error('Error fetching URL:', url, request.status); + return; + } + const data = new DataTransfer(); + data.items.add(new File([await request.blob()], 'image.png')); + fileInput.files = data.files; + fileInput.dispatchEvent(new Event('change')); + } catch (error) { + console.error('Error fetching URL:', url, error); + return; + } } } diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 688c2f112..b07ba97cb 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -64,6 +64,14 @@ function keyupEditAttention(event) { selectionEnd++; } + // deselect surrounding whitespace + while (text[selectionStart] == " " && selectionStart < selectionEnd) { + selectionStart++; + } + while (text[selectionEnd - 1] == " " && selectionEnd > selectionStart) { + selectionEnd--; + } + target.setSelectionRange(selectionStart, selectionEnd); return true; } diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index d5855fe96..c5cced973 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -39,12 +39,12 @@ function setupExtraNetworksForTab(tabname) { // tabname_full = {tabname}_{extra_networks_tabname} var tabname_full = elem.id; var search = gradioApp().querySelector("#" + tabname_full + "_extra_search"); - var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort"); var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir"); var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh"); + var currentSort = ''; // If any of the buttons above don't exist, we want to skip this iteration of the loop. - if (!search || !sort_mode || !sort_dir || !refresh) { + if (!search || !sort_dir || !refresh) { return; // `return` is equivalent of `continue` but for forEach loops. 
} @@ -52,7 +52,7 @@ function setupExtraNetworksForTab(tabname) { var searchTerm = search.value.toLowerCase(); gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) { return t.textContent.toLowerCase(); }).join(" "); @@ -71,42 +71,46 @@ function setupExtraNetworksForTab(tabname) { }; var applySort = function(force) { - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card'); + var parent = gradioApp().querySelector('#' + tabname_full + "_cards"); var reverse = sort_dir.dataset.sortdir == "Descending"; - var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; - sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); - var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + var activeSearchElem = gradioApp().querySelector('#' + tabname_full + "_controls .extra-network-control--sort.extra-network-control--enabled"); + var sortKey = activeSearchElem ? activeSearchElem.dataset.sortkey : "default"; + var sortKeyDataField = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + sort_dir.dataset.sortdir + "-" + cards.length; - if (sortKeyStore == sort_mode.dataset.sortkey && !force) { + if (sortKeyStore == currentSort && !force) { return; } - sort_mode.dataset.sortkey = sortKeyStore; + currentSort = sortKeyStore; - cards.forEach(function(card) { - card.originalParentElement = card.parentElement; - }); var sortedCards = Array.from(cards); sortedCards.sort(function(cardA, cardB) { - var a = cardA.dataset[sortKey]; - var b = cardB.dataset[sortKey]; + var a = cardA.dataset[sortKeyDataField]; + var b = cardB.dataset[sortKeyDataField]; if (!isNaN(a) && !isNaN(b)) { return parseInt(a) - parseInt(b); } return (a < b ? -1 : (a > b ? 1 : 0)); }); + if (reverse) { sortedCards.reverse(); } - cards.forEach(function(card) { - card.remove(); - }); + + parent.innerHTML = ''; + + var frag = document.createDocumentFragment(); sortedCards.forEach(function(card) { - card.originalParentElement.appendChild(card); + frag.appendChild(card); }); + parent.appendChild(frag); }; - search.addEventListener("input", applyFilter); + search.addEventListener("input", function() { + applyFilter(); + }); applySort(); applyFilter(); extraNetworksApplySort[tabname_full] = applySort; @@ -272,6 +276,15 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksSearchButton(tabname, extra_networks_tabname, event) { + var searchTextarea = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search"); + var button = event.target; + var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); + + searchTextarea.value = text; + updateInput(searchTextarea); +} + function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on files in tree. @@ -290,7 +303,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_netwo * Processes `onclick` events when user clicks on directories in tree. 
* * Here is how the tree reacts to clicks for various states: - * unselected unopened directory: Diretory is selected and expanded. + * unselected unopened directory: Directory is selected and expanded. * unselected opened directory: Directory is selected. * selected opened directory: Directory is collapsed and deselected. * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. @@ -383,36 +396,17 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { } function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) { - /** - * Handles `onclick` events for the Sort Mode button. - * - * Modifies the data attributes of the Sort Mode button to cycle between - * various sorting modes. - * - * @param event The generated event. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. - */ - var curr_mode = event.currentTarget.dataset.sortmode; - var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir"); - var sort_dir = el_sort_dir.dataset.sortdir; - if (curr_mode == "path") { - event.currentTarget.dataset.sortmode = "name"; - event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by filename"); - } else if (curr_mode == "name") { - event.currentTarget.dataset.sortmode = "date_created"; - event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by date created"); - } else if (curr_mode == "date_created") { - event.currentTarget.dataset.sortmode = "date_modified"; - event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by date modified"); - } else { - event.currentTarget.dataset.sortmode = "path"; - event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by path"); - } + /** Handles `onclick` events for Sort Mode buttons. */ + + var self = event.currentTarget; + var parent = event.currentTarget.parentElement; + + parent.querySelectorAll('.extra-network-control--sort').forEach(function(x) { + x.classList.remove('extra-network-control--enabled'); + }); + + self.classList.add('extra-network-control--enabled'); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } @@ -447,8 +441,12 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. 
*/ - gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree").classList.toggle("hidden"); - event.currentTarget.classList.toggle("extra-network-control--enabled"); + var button = event.currentTarget; + button.classList.toggle("extra-network-control--enabled"); + var show = !button.classList.contains("extra-network-control--enabled"); + + var pane = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_pane"); + pane.classList.toggle("extra-network-dirs-hidden", show); } function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { @@ -509,12 +507,76 @@ function popupId(id) { popup(storedPopupIds[id]); } +function extraNetworksFlattenMetadata(obj) { + const result = {}; + + // Convert any stringified JSON objects to actual objects + for (const key of Object.keys(obj)) { + if (typeof obj[key] === 'string') { + try { + const parsed = JSON.parse(obj[key]); + if (parsed && typeof parsed === 'object') { + obj[key] = parsed; + } + } catch (error) { + continue; + } + } + } + + // Flatten the object + for (const key of Object.keys(obj)) { + if (typeof obj[key] === 'object' && obj[key] !== null) { + const nested = extraNetworksFlattenMetadata(obj[key]); + for (const nestedKey of Object.keys(nested)) { + result[`${key}/${nestedKey}`] = nested[nestedKey]; + } + } else { + result[key] = obj[key]; + } + } + + // Special case for handling modelspec keys + for (const key of Object.keys(result)) { + if (key.startsWith("modelspec.")) { + result[key.replaceAll(".", "/")] = result[key]; + delete result[key]; + } + } + + // Add empty keys to designate hierarchy + for (const key of Object.keys(result)) { + const parts = key.split("/"); + for (let i = 1; i < parts.length; i++) { + const parent = parts.slice(0, i).join("/"); + if (!result[parent]) { + result[parent] = ""; + } + } + } + + return result; +} + function extraNetworksShowMetadata(text) { + try { + let parsed = JSON.parse(text); + if (parsed && typeof parsed === 'object') { + parsed = extraNetworksFlattenMetadata(parsed); + const table = createVisualizationTable(parsed, 0); + popup(table); + return; + } + } catch (error) { + console.error(error); + } + var elem = document.createElement('pre'); elem.classList.add('popup-metadata'); elem.textContent = text; popup(elem); + return; } function requestGet(url, data, handler, errorHandler) { @@ -543,16 +605,18 @@ function requestGet(url, data, handler, errorHandler) { xhr.send(js); } -function extraNetworksCopyCardPath(event, path) { - navigator.clipboard.writeText(path); +function extraNetworksCopyCardPath(event) { + navigator.clipboard.writeText(event.target.getAttribute("data-clipboard-text")); event.stopPropagation(); } -function extraNetworksRequestMetadata(event, extraPage, cardName) { +function extraNetworksRequestMetadata(event, extraPage) { var showError = function() { extraNetworksShowMetadata("there was an error getting metadata"); }; + var cardName = event.target.parentElement.parentElement.getAttribute("data-name"); + requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) { if (data && data.metadata) { extraNetworksShowMetadata(data.metadata); @@ -566,7 +630,7 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) { var extraPageUserMetadataEditors = {}; -function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { +function extraNetworksEditUserMetadata(event, tabname, extraPage) { var id = tabname + '_' + extraPage + '_edit_user_metadata'; var editor = 
extraPageUserMetadataEditors[id]; @@ -578,6 +642,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { extraPageUserMetadataEditors[id] = editor; } + var cardName = event.target.parentElement.parentElement.getAttribute("data-name"); editor.nameTextarea.value = cardName; updateInput(editor.nameTextarea); diff --git a/javascript/gradio.js b/javascript/gradio.js new file mode 100644 index 000000000..e68b98b04 --- /dev/null +++ b/javascript/gradio.js @@ -0,0 +1,7 @@ + +// added to fix a weird error in gradio 4.19 at page load +Object.defineProperty(Array.prototype, 'toLowerCase', { + value: function() { + return this; + } +}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 625c5d148..ff673a02a 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -6,6 +6,8 @@ function closeModal() { function showModal(event) { const source = event.target || event.srcElement; const modalImage = gradioApp().getElementById("modalImage"); + const modalToggleLivePreviewBtn = gradioApp().getElementById("modal_toggle_live_preview"); + modalToggleLivePreviewBtn.innerHTML = opts.js_live_preview_in_modal_lightbox ? "🗇" : "🗆"; const lb = gradioApp().getElementById("lightboxModal"); modalImage.src = source.src; if (modalImage.style.display === 'none') { @@ -51,14 +53,7 @@ function modalImageSwitch(offset) { var galleryButtons = all_gallery_buttons(); if (galleryButtons.length > 1) { - var currentButton = selected_gallery_button(); - - var result = -1; - galleryButtons.forEach(function(v, i) { - if (v == currentButton) { - result = i; - } - }); + var result = selected_gallery_index(); if (result != -1) { var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]; @@ -131,19 +126,15 @@ function setupImageForLightbox(e) { e.style.cursor = 'pointer'; e.style.userSelect = 'none'; - var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1; - - // For Firefox, listening on click first switched to next image then shows the lightbox. - // If you know how to fix this without switching to mousedown event, please. - // For other browsers the event is click to make it possiblr to drag picture. - var event = isFirefox ? 'mousedown' : 'click'; - - e.addEventListener(event, function(evt) { + e.addEventListener('mousedown', function(evt) { if (evt.button == 1) { open(evt.target.src); evt.preventDefault(); return; } + }, true); + + e.addEventListener('click', function(evt) { if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); @@ -163,6 +154,13 @@ function modalZoomToggle(event) { event.stopPropagation(); } +function modalLivePreviewToggle(event) { + const modalToggleLivePreview = gradioApp().getElementById("modal_toggle_live_preview"); + opts.js_live_preview_in_modal_lightbox = !opts.js_live_preview_in_modal_lightbox; + modalToggleLivePreview.innerHTML = opts.js_live_preview_in_modal_lightbox ? 
"🗇" : "🗆"; + event.stopPropagation(); +} + function modalTileImageToggle(event) { const modalImage = gradioApp().getElementById("modalImage"); const modal = gradioApp().getElementById("lightboxModal"); @@ -179,7 +177,7 @@ function modalTileImageToggle(event) { } onAfterUiUpdate(function() { - var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img'); + var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > button > button > img'); if (fullImg_preview != null) { fullImg_preview.forEach(setupImageForLightbox); } @@ -220,6 +218,14 @@ document.addEventListener("DOMContentLoaded", function() { modalSave.title = "Save Image(s)"; modalControls.appendChild(modalSave); + const modalToggleLivePreview = document.createElement('span'); + modalToggleLivePreview.className = 'modalToggleLivePreview cursor'; + modalToggleLivePreview.id = "modal_toggle_live_preview"; + modalToggleLivePreview.innerHTML = "🗆"; + modalToggleLivePreview.onclick = modalLivePreviewToggle; + modalToggleLivePreview.title = "Toggle live preview"; + modalControls.appendChild(modalToggleLivePreview); + const modalClose = document.createElement('span'); modalClose.className = 'modalClose cursor'; modalClose.innerHTML = '×'; diff --git a/javascript/profilerVisualization.js b/javascript/profilerVisualization.js index 9d8e5f42f..9822f4b2a 100644 --- a/javascript/profilerVisualization.js +++ b/javascript/profilerVisualization.js @@ -33,120 +33,141 @@ function createRow(table, cellName, items) { return res; } -function showProfile(path, cutoff = 0.05) { - requestGet(path, {}, function(data) { - var table = document.createElement('table'); - table.className = 'popup-table'; - - data.records['total'] = data.total; - var keys = Object.keys(data.records).sort(function(a, b) { - return data.records[b] - data.records[a]; +function createVisualizationTable(data, cutoff = 0, sort = "") { + var table = document.createElement('table'); + table.className = 'popup-table'; + + var keys = Object.keys(data); + if (sort === "number") { + keys = keys.sort(function(a, b) { + return data[b] - data[a]; }); - var items = keys.map(function(x) { - return {key: x, parts: x.split('/'), time: data.records[x]}; + } else { + keys = keys.sort(); + } + var items = keys.map(function(x) { + return {key: x, parts: x.split('/'), value: data[x]}; + }); + var maxLength = items.reduce(function(a, b) { + return Math.max(a, b.parts.length); + }, 0); + + var cols = createRow( + table, + 'th', + [ + cutoff === 0 ? 'key' : 'record', + cutoff === 0 ? 
'value' : 'seconds' + ] + ); + cols[0].colSpan = maxLength; + + function arraysEqual(a, b) { + return !(a < b || b < a); + } + + var addLevel = function(level, parent, hide) { + var matching = items.filter(function(x) { + return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent); }); - var maxLength = items.reduce(function(a, b) { - return Math.max(a, b.parts.length); - }, 0); - - var cols = createRow(table, 'th', ['record', 'seconds']); - cols[0].colSpan = maxLength; - - function arraysEqual(a, b) { - return !(a < b || b < a); + if (sort === "number") { + matching = matching.sort(function(a, b) { + return b.value - a.value; + }); + } else { + matching = matching.sort(); } + var othersTime = 0; + var othersList = []; + var othersRows = []; + var childrenRows = []; + matching.forEach(function(x) { + var visible = (cutoff === 0 && !hide) || (x.value >= cutoff && !hide); + + var cells = []; + for (var i = 0; i < maxLength; i++) { + cells.push(x.parts[i]); + } + cells.push(cutoff === 0 ? x.value : x.value.toFixed(3)); + var cols = createRow(table, 'td', cells); + for (i = 0; i < level; i++) { + cols[i].className = 'muted'; + } - var addLevel = function(level, parent, hide) { - var matching = items.filter(function(x) { - return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent); - }); - var sorted = matching.sort(function(a, b) { - return b.time - a.time; - }); - var othersTime = 0; - var othersList = []; - var othersRows = []; - var childrenRows = []; - sorted.forEach(function(x) { - var visible = x.time >= cutoff && !hide; - - var cells = []; - for (var i = 0; i < maxLength; i++) { - cells.push(x.parts[i]); - } - cells.push(x.time.toFixed(3)); - var cols = createRow(table, 'td', cells); - for (i = 0; i < level; i++) { - cols[i].className = 'muted'; - } - - var tr = cols[0].parentNode; - if (!visible) { - tr.classList.add("hidden"); - } - - if (x.time >= cutoff) { - childrenRows.push(tr); - } else { - othersTime += x.time; - othersList.push(x.parts[level]); - othersRows.push(tr); - } - - var children = addLevel(level + 1, parent.concat([x.parts[level]]), true); - if (children.length > 0) { - var cell = cols[level]; - var onclick = function() { - cell.classList.remove("link"); - cell.removeEventListener("click", onclick); - children.forEach(function(x) { - x.classList.remove("hidden"); - }); - }; - cell.classList.add("link"); - cell.addEventListener("click", onclick); - } - }); + var tr = cols[0].parentNode; + if (!visible) { + tr.classList.add("hidden"); + } - if (othersTime > 0) { - var cells = []; - for (var i = 0; i < maxLength; i++) { - cells.push(parent[i]); - } - cells.push(othersTime.toFixed(3)); - cells[level] = 'others'; - var cols = createRow(table, 'td', cells); - for (i = 0; i < level; i++) { - cols[i].className = 'muted'; - } + if (cutoff === 0 || x.value >= cutoff) { + childrenRows.push(tr); + } else { + othersTime += x.value; + othersList.push(x.parts[level]); + othersRows.push(tr); + } + var children = addLevel(level + 1, parent.concat([x.parts[level]]), true); + if (children.length > 0) { var cell = cols[level]; - var tr = cell.parentNode; var onclick = function() { - tr.classList.add("hidden"); cell.classList.remove("link"); cell.removeEventListener("click", onclick); - othersRows.forEach(function(x) { + children.forEach(function(x) { x.classList.remove("hidden"); }); }; - - cell.title = othersList.join(", "); cell.classList.add("link"); cell.addEventListener("click", onclick); + } + }); - if (hide) { - 
tr.classList.add("hidden"); - } + if (othersTime > 0) { + var cells = []; + for (var i = 0; i < maxLength; i++) { + cells.push(parent[i]); + } + cells.push(othersTime.toFixed(3)); + cells[level] = 'others'; + var cols = createRow(table, 'td', cells); + for (i = 0; i < level; i++) { + cols[i].className = 'muted'; + } - childrenRows.push(tr); + var cell = cols[level]; + var tr = cell.parentNode; + var onclick = function() { + tr.classList.add("hidden"); + cell.classList.remove("link"); + cell.removeEventListener("click", onclick); + othersRows.forEach(function(x) { + x.classList.remove("hidden"); + }); + }; + + cell.title = othersList.join(", "); + cell.classList.add("link"); + cell.addEventListener("click", onclick); + + if (hide) { + tr.classList.add("hidden"); } - return childrenRows; - }; + childrenRows.push(tr); + } + + return childrenRows; + }; - addLevel(0, []); + addLevel(0, []); + + return table; +} +function showProfile(path, cutoff = 0.05) { + requestGet(path, {}, function(data) { + data.records['total'] = data.total; + const table = createVisualizationTable(data.records, cutoff, "number"); popup(table); }); } diff --git a/javascript/progressbar.js b/javascript/progressbar.js index f068bac6a..23dea64ce 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -76,6 +76,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var dateStart = new Date(); var wasEverActive = false; var parentProgressbar = progressbarContainer.parentNode; + var wakeLock = null; + + var requestWakeLock = async function() { + if (!opts.prevent_screen_sleep_during_generation || wakeLock) return; + try { + wakeLock = await navigator.wakeLock.request('screen'); + } catch (err) { + console.error('Wake Lock is not supported.'); + } + }; + + var releaseWakeLock = async function() { + if (!opts.prevent_screen_sleep_during_generation || !wakeLock) return; + try { + await wakeLock.release(); + wakeLock = null; + } catch (err) { + console.error('Wake Lock release failed', err); + } + }; var divProgress = document.createElement('div'); divProgress.className = 'progressDiv'; @@ -89,6 +109,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var livePreview = null; var removeProgressBar = function() { + releaseWakeLock(); if (!divProgress) return; setTitle(""); @@ -100,6 +121,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre }; var funProgress = function(id_task) { + requestWakeLock(); request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) { if (res.completed) { removeProgressBar(); diff --git a/javascript/resizeHandle.js b/javascript/resizeHandle.js index 6560372cc..4aeb14b41 100644 --- a/javascript/resizeHandle.js +++ b/javascript/resizeHandle.js @@ -2,6 +2,7 @@ const GRADIO_MIN_WIDTH = 320; const PAD = 16; const DEBOUNCE_TIME = 100; + const DOUBLE_TAP_DELAY = 200; //ms const R = { tracking: false, @@ -10,6 +11,7 @@ leftCol: null, leftColStartWidth: null, screenX: null, + lastTapTime: null, }; let resizeTimer; @@ -20,6 +22,9 @@ } function displayResizeHandle(parent) { + if (!parent.needHideOnMoblie) { + return true; + } if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) { parent.style.display = 'flex'; parent.resizeHandle.style.display = "none"; @@ -39,7 +44,7 @@ const ratio = newParentWidth / oldParentWidth; - const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH); + const newWidthL = Math.max(Math.floor(ratio * widthL), parent.minLeftColWidth); 
setLeftColGridTemplate(parent, newWidthL); R.parentWidth = newParentWidth; @@ -47,6 +52,14 @@ } function setup(parent) { + + function onDoubleClick(evt) { + evt.preventDefault(); + evt.stopPropagation(); + + parent.style.gridTemplateColumns = parent.style.originalGridTemplateColumns; + } + const leftCol = parent.firstElementChild; const rightCol = parent.lastElementChild; @@ -54,7 +67,24 @@ parent.style.display = 'grid'; parent.style.gap = '0'; - const gridTemplateColumns = `${parent.children[0].style.flexGrow}fr ${PAD}px ${parent.children[1].style.flexGrow}fr`; + let leftColTemplate = ""; + if (parent.children[0].style.flexGrow) { + leftColTemplate = `${parent.children[0].style.flexGrow}fr`; + parent.minLeftColWidth = GRADIO_MIN_WIDTH; + parent.minRightColWidth = GRADIO_MIN_WIDTH; + parent.needHideOnMoblie = true; + } else { + leftColTemplate = parent.children[0].style.flexBasis; + parent.minLeftColWidth = parent.children[0].style.flexBasis.slice(0, -2) / 2; + parent.minRightColWidth = 0; + parent.needHideOnMoblie = false; + } + + if (!leftColTemplate) { + leftColTemplate = '1fr'; + } + + const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`; parent.style.gridTemplateColumns = gridTemplateColumns; parent.style.originalGridTemplateColumns = gridTemplateColumns; @@ -69,6 +99,14 @@ if (evt.button !== 0) return; } else { if (evt.changedTouches.length !== 1) return; + + const currentTime = new Date().getTime(); + if (R.lastTapTime && currentTime - R.lastTapTime <= DOUBLE_TAP_DELAY) { + onDoubleClick(evt); + return; + } + + R.lastTapTime = currentTime; } evt.preventDefault(); @@ -89,12 +127,7 @@ }); }); - resizeHandle.addEventListener('dblclick', (evt) => { - evt.preventDefault(); - evt.stopPropagation(); - - parent.style.gridTemplateColumns = parent.style.originalGridTemplateColumns; - }); + resizeHandle.addEventListener('dblclick', onDoubleClick); afterResize(parent); } @@ -119,7 +152,7 @@ } else { delta = R.screenX - evt.changedTouches[0].screenX; } - const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH); + const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - R.parent.minRightColWidth - PAD), R.parent.minLeftColWidth); setLeftColGridTemplate(R.parent, leftColWidth); } }); @@ -158,10 +191,15 @@ setupResizeHandle = setup; })(); -onUiLoaded(function() { + +function setupAllResizeHandles() { for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) { - if (!elem.querySelector('.resize-handle')) { + if (!elem.querySelector('.resize-handle') && !elem.children[0].classList.contains("hidden")) { setupResizeHandle(elem); } } -}); +} + + +onUiLoaded(setupAllResizeHandles); + diff --git a/javascript/ui.js b/javascript/ui.js index f2adc7dd8..2c5db4838 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -26,13 +26,18 @@ function selected_gallery_index() { return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected')); } +function gallery_container_buttons(gallery_container) { + return gradioApp().querySelectorAll(`#${gallery_container} .thumbnail-item.thumbnail-small`); +} + +function selected_gallery_index_id(gallery_container) { + return Array.from(gallery_container_buttons(gallery_container)).findIndex(elem => elem.classList.contains('selected')); +} + function extract_image_from_gallery(gallery) { if (gallery.length == 0) { return [null]; } - if (gallery.length == 1) { - return [gallery[0]]; - } var index = 
selected_gallery_index(); @@ -41,7 +46,7 @@ function extract_image_from_gallery(gallery) { index = 0; } - return [gallery[index]]; + return [[gallery[index]]]; } window.args_to_array = Array.from; // Compatibility with e.g. extensions that may expect this to be around @@ -113,14 +118,6 @@ function get_img2img_tab_index() { function create_submit_args(args) { var res = Array.from(args); - // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image. - // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate. - // I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some. - // If gradio at some point stops sending outputs, this may break something - if (Array.isArray(res[res.length - 3])) { - res[res.length - 3] = null; - } - return res; } @@ -141,11 +138,10 @@ function showSubmitInterruptingPlaceholder(tabname) { function showRestoreProgressButton(tabname, show) { var button = gradioApp().getElementById(tabname + "_restore_progress"); if (!button) return; - - button.style.display = show ? "flex" : "none"; + button.style.setProperty('display', show ? 'flex' : 'none', 'important'); } -function submit() { +function submit(args) { showSubmitButtons('txt2img', false); var id = randomId(); @@ -157,22 +153,22 @@ function submit() { showRestoreProgressButton('txt2img', false); }); - var res = create_submit_args(arguments); + var res = create_submit_args(args); res[0] = id; return res; } -function submit_txt2img_upscale() { - var res = submit(...arguments); +function submit_txt2img_upscale(args) { + var res = submit(...args); res[2] = selected_gallery_index(); return res; } -function submit_img2img() { +function submit_img2img(args) { showSubmitButtons('img2img', false); var id = randomId(); @@ -184,15 +180,14 @@ function submit_img2img() { showRestoreProgressButton('img2img', false); }); - var res = create_submit_args(arguments); + var res = create_submit_args(args); res[0] = id; - res[1] = get_tab_index('mode_img2img'); return res; } -function submit_extras() { +function submit_extras(args) { showSubmitButtons('extras', false); var id = randomId(); @@ -201,11 +196,10 @@ function submit_extras() { showSubmitButtons('extras', true); }); - var res = create_submit_args(arguments); + var res = create_submit_args(args); res[0] = id; - console.log(res); return res; } @@ -214,6 +208,7 @@ function restoreProgressTxt2img() { var id = localGet("txt2img_task_id"); if (id) { + showSubmitInterruptingPlaceholder('txt2img'); requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); }, null, 0); @@ -228,6 +223,7 @@ function restoreProgressImg2img() { var id = localGet("img2img_task_id"); if (id) { + showSubmitInterruptingPlaceholder('img2img'); requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); }, null, 0); @@ -303,6 +299,7 @@ onAfterUiUpdate(function() { var jsdata = textarea.value; opts = JSON.parse(jsdata); + executeCallbacks(optionsAvailableCallbacks); /*global optionsAvailableCallbacks*/ executeCallbacks(optionsChangedCallbacks); /*global optionsChangedCallbacks*/ 
Object.defineProperty(textarea, 'value', { @@ -341,8 +338,8 @@ onOptionsChanged(function() { let txt2img_textarea, img2img_textarea = undefined; function restart_reload() { + document.body.style.backgroundColor = "var(--background-fill-primary)"; document.body.innerHTML = '

Reloading...

'; - var requestPing = function() { requestGet("./internal/ping", {}, function(data) { location.reload(); @@ -371,9 +368,9 @@ function selectCheckpoint(name) { gradioApp().getElementById('change_checkpoint').click(); } -function currentImg2imgSourceResolution(w, h, scaleBy) { - var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img'); - return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]; +function currentImg2imgSourceResolution(w, h, r) { + var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] :is(img, canvas)'); + return img ? [img.naturalWidth || img.width, img.naturalHeight || img.height, r] : [0, 0, r]; } function updateImg2imgResizeToTextAfterChangingImage() { @@ -416,7 +413,7 @@ function switchWidthHeight(tabname) { var onEditTimers = {}; -// calls func after afterMs milliseconds has passed since the input elem has beed enited by user +// calls func after afterMs milliseconds has passed since the input elem has been edited by user function onEdit(editId, elem, afterMs, func) { var edited = function() { var existingTimer = onEditTimers[editId]; diff --git a/javascript/ui_settings_hints.js b/javascript/ui_settings_hints.js index d088f9494..c3984bd02 100644 --- a/javascript/ui_settings_hints.js +++ b/javascript/ui_settings_hints.js @@ -14,10 +14,16 @@ onOptionsChanged(function() { if (!commentBefore && !commentAfter) return; var span = null; - if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span'); - else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild; - else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild; - else span = div.querySelector('label span').firstChild; + if (div.classList.contains('gradio-checkbox')) { + span = div.querySelector('label span'); + } else if (div.classList.contains('gradio-checkboxgroup')) { + span = div.querySelector('span').firstChild; + } else if (div.classList.contains('gradio-radio')) { + span = div.querySelector('span').firstChild; + } else { + var elem = div.querySelector('label span'); + if (elem) span = elem.firstChild; + } if (!span) return; diff --git a/ldm_patched/ldm/models/autoencoder.py b/ldm_patched/ldm/models/autoencoder.py index fadefee82..d97899d40 100644 --- a/ldm_patched/ldm/models/autoencoder.py +++ b/ldm_patched/ldm/models/autoencoder.py @@ -182,7 +182,7 @@ def get_autoencoder_params(self) -> list: return params def encode( - self, x: torch.Tensor, return_reg_log: bool = False + self, x: torch.Tensor, regulation=None, return_reg_log: bool = False ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.max_batch_size is None: z = self.encoder(x) @@ -198,7 +198,7 @@ def encode( z.append(z_batch) z = torch.cat(z, 0) - z, reg_log = self.regularization(z) + z, reg_log = self.regularization(z) if regulation is None else regulation(z) if return_reg_log: return z, reg_log return z diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py index d1093dc6d..02d992964 100644 --- a/ldm_patched/modules/model_patcher.py +++ b/ldm_patched/modules/model_patcher.py @@ -81,6 +81,9 @@ def set_model_vae_encode_wrapper(self, wrapper_function): def set_model_vae_decode_wrapper(self, wrapper_function): self.model_options["model_vae_decode_wrapper"] = wrapper_function + def set_model_vae_regulation(self, vae_regulation): + self.model_options["model_vae_regulation"] = vae_regulation + def set_model_patch(self, patch, name): to 
= self.model_options["transformer_options"] if "patches" not in to: diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index 2830cc721..934c0092a 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -296,6 +296,8 @@ def encode_inner(self, pixel_samples): if model_management.VAE_ALWAYS_TILED: return self.encode_tiled(pixel_samples) + regulation = self.patcher.model_options.get("model_vae_regulation", None) + pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) @@ -306,7 +308,7 @@ def encode_inner(self, pixel_samples): samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in, regulation).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") diff --git a/modules/api/api.py b/modules/api/api.py index d5348bb24..78d109697 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -2,13 +2,11 @@ import io import os import time -import itertools import datetime import uvicorn import ipaddress import requests import gradio as gr -import numpy as np from threading import Lock from io import BytesIO from fastapi import APIRouter, Depends, FastAPI, Request, Response @@ -19,13 +17,13 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin, Image +from PIL import PngImagePlugin from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -45,7 +43,7 @@ def script_name_to_index(name, scripts): def validate_sampler_name(name): config = sd_samplers.all_samplers_map.get(name, None) if config is None: - raise HTTPException(status_code=404, detail="Sampler not found") + raise HTTPException(status_code=400, detail="Sampler not found") return name @@ -87,7 +85,7 @@ def decode_base64_to_image(encoding): headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} response = requests.get(encoding, timeout=30, headers=headers) try: - image = Image.open(BytesIO(response.content)) + image = images.read(BytesIO(response.content)) return image except Exception as e: raise HTTPException(status_code=500, detail="Invalid image url") from e @@ -95,7 +93,7 @@ def 
decode_base64_to_image(encoding): if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] try: - image = Image.open(BytesIO(base64.b64decode(encoding))) + image = images.read(BytesIO(base64.b64decode(encoding))) return image except Exception as e: raise HTTPException(status_code=500, detail="Invalid encoded image") from e @@ -105,8 +103,6 @@ def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: if isinstance(image, str): return image - if isinstance(image, np.ndarray): - image = Image.fromarray(image) if opts.samples_format.lower() == 'png': use_metadata = False metadata = PngImagePlugin.PngInfo() @@ -117,7 +113,7 @@ def encode_pil_to_base64(image): image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): - if image.mode == "RGBA": + if image.mode in ("RGBA", "P"): image = image.convert("RGB") parameters = image.info.get('parameters', None) exif_bytes = piexif.dump({ @@ -211,7 +207,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.router = APIRouter() self.app = app self.queue_lock = queue_lock - api_middleware(self.app) + #api_middleware(self.app) # XXX this will have to be fixed self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse) self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) @@ -225,6 +221,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem]) + self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem]) self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem]) @@ -364,7 +361,7 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None): - """Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext. + """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext. If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored. 
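
The docstring above states the precedence rule: a field explicitly set on the request wins over the same field parsed out of the infotext. A standalone sketch of that merge, with hypothetical field names; the real code resolves fields through infotext_utils.paste_fields and request.model_dump(exclude_unset=True) rather than a plain dict:

    def merge_infotext(request_fields: dict, infotext_params: dict) -> dict:
        """Fill in values parsed from infotext without overriding explicitly set fields."""
        merged = dict(request_fields)
        for name, value in infotext_params.items():
            if name not in merged:  # only fill fields the caller did not set
                merged[name] = value
        return merged

    # Example: the request pins cfg_scale, so only steps is taken from the infotext.
    print(merge_infotext({"cfg_scale": 5.0}, {"cfg_scale": 7.0, "steps": 30}))
    # -> {'cfg_scale': 5.0, 'steps': 30}
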
@@ -375,7 +372,7 @@ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_scri return {} possible_fields = infotext_utils.paste_fields[tabname]["fields"] - set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this + set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this params = infotext_utils.parse_generation_parameters(request.infotext) def get_field_value(field, params): @@ -413,8 +410,8 @@ def get_field_value(field, params): if request.override_settings is None: request.override_settings = {} - overriden_settings = infotext_utils.get_override_settings(params) - for _, setting_name, value in overriden_settings: + overridden_settings = infotext_utils.get_override_settings(params) + for _, setting_name, value in overridden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -441,15 +438,19 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) + sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler) populate = txt2imgreq.copy(update={ # Override __init__ params - "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), + "sampler_name": validate_sampler_name(sampler), "do_not_save_samples": not txt2imgreq.save_images, "do_not_save_grid": not txt2imgreq.save_images, }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on + if not populate.scheduler and scheduler != "Automatic": + populate.scheduler = scheduler + args = vars(populate) args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them @@ -484,11 +485,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): shared.state.end() shared.total_tqdm.clear() - b64images = [ - encode_pil_to_base64(image) - for image in itertools.chain(processed.images, processed.extra_images) - if send_images - ] + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -509,9 +506,10 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler) populate = img2imgreq.copy(update={ # Override __init__ params - "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), + "sampler_name": validate_sampler_name(sampler), "do_not_save_samples": not img2imgreq.save_images, "do_not_save_grid": not img2imgreq.save_images, "mask": mask, @@ -519,6 +517,9 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): if 
populate.sampler_name: populate.sampler_index = None # prevent a warning later on + if not populate.scheduler and scheduler != "Automatic": + populate.scheduler = scheduler + args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) @@ -555,11 +556,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): shared.state.end() shared.total_tqdm.clear() - b64images = [ - encode_pil_to_base64(image) - for image in itertools.chain(processed.images, processed.extra_images) - if send_images - ] + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] if not img2imgreq.include_init_images: img2imgreq.init_images = None @@ -695,6 +692,17 @@ def get_cmd_flags(self): def get_samplers(self): return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] + def get_schedulers(self): + return [ + { + "name": scheduler.name, + "label": scheduler.label, + "aliases": scheduler.aliases, + "default_rho": scheduler.default_rho, + "need_inner_model": scheduler.need_inner_model, + } + for scheduler in sd_schedulers.schedulers] + def get_upscalers(self): return [ { diff --git a/modules/api/models.py b/modules/api/models.py index 16edf11cf..f44e5dca0 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,6 +1,6 @@ import inspect -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, Field, create_model, ConfigDict from typing import Any, Optional, Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img @@ -92,9 +92,7 @@ def generate_model(self): fields = { d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def } - DynamicModel = create_model(self._model_name, **fields) - DynamicModel.__config__.allow_population_by_field_name = True - DynamicModel.__config__.allow_mutation = True + DynamicModel = create_model(self._model_name, __config__=ConfigDict(populate_by_name=True, frozen=False), **fields) return DynamicModel StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( @@ -102,13 +100,13 @@ def generate_model(self): StableDiffusionProcessingTxt2Img, [ {"key": "sampler_index", "type": str, "default": "Euler"}, - {"key": "script_name", "type": str, "default": None}, + {"key": "script_name", "type": str | None, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, - {"key": "force_task_id", "type": str, "default": None}, - {"key": "infotext", "type": str, "default": None}, + {"key": "force_task_id", "type": str | None, "default": None}, + {"key": "infotext", "type": str | None, "default": None}, ] ).generate_model() @@ -117,27 +115,27 @@ def generate_model(self): StableDiffusionProcessingImg2Img, [ {"key": "sampler_index", "type": str, "default": "Euler"}, - {"key": "init_images", "type": list, "default": None}, + {"key": "init_images", "type": list | None, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, - {"key": "mask", "type": str, "default": None}, + {"key": "mask", "type": str | None, "default": None}, {"key": "include_init_images", "type": bool, "default": 
False, "exclude" : True}, - {"key": "script_name", "type": str, "default": None}, + {"key": "script_name", "type": str | None, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, - {"key": "force_task_id", "type": str, "default": None}, - {"key": "infotext", "type": str, "default": None}, + {"key": "force_task_id", "type": str | None, "default": None}, + {"key": "infotext", "type": str | None, "default": None}, ] ).generate_model() class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: list[str] | None = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: list[str] | None = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str @@ -147,7 +145,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", gt=0, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") @@ -163,7 +161,7 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") class ExtrasSingleImageResponse(ExtraBaseResponse): - image: str = Field(default=None, title="Image", description="The generated image in base64 format.") + image: str | None = Field(default=None, title="Image", description="The generated image in base64 format.") class FileData(BaseModel): data: str = Field(title="File data", description="Base64 representation of the file") @@ -190,15 +188,15 @@ class ProgressResponse(BaseModel): progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") - current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") - textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") + current_image: str | None = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.") class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") model: str = Field(default="clip", title="Model", description="The interrogate model used.") class InterrogateResponse(BaseModel): - caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + caption: str | None = Field(default=None, title="Caption", description="The generated caption for the image.") class TrainResponse(BaseModel): info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.") @@ -223,7 +221,7 @@ class CreateResponse(BaseModel): for key in _options: if(_options[key].dest != 'help'): flag = _options[key] - _type = str + _type = str | None if _options[key].default is not None: _type = type(_options[key].default) flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))}) @@ -233,9 +231,19 @@ class CreateResponse(BaseModel): class SamplerItem(BaseModel): name: str = Field(title="Name") aliases: list[str] = Field(title="Aliases") - options: dict[str, str] = Field(title="Options") + options: dict[str, Any] = Field(title="Options") + +class SchedulerItem(BaseModel): + name: str = Field(title="Name") + label: str = Field(title="Label") + aliases: Optional[list[str]] = Field(title="Aliases") + default_rho: Optional[float] = Field(title="Default Rho") + need_inner_model: Optional[bool] = Field(title="Needs Inner Model") class UpscalerItem(BaseModel): + class Config: + protected_namespaces = () + name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") @@ -246,6 +254,9 @@ class LatentUpscalerModeItem(BaseModel): name: str = Field(title="Name") class SDModelItem(BaseModel): + class Config: + protected_namespaces = () + title: str = Field(title="Title") 
model_name: str = Field(title="Model Name") hash: Optional[str] = Field(title="Short hash") @@ -254,6 +265,9 @@ class SDModelItem(BaseModel): config: Optional[str] = Field(title="Config file") class SDVaeItem(BaseModel): + class Config: + protected_namespaces = () + model_name: str = Field(title="Model Name") filename: str = Field(title="Filename") @@ -293,12 +307,12 @@ class MemoryResponse(BaseModel): class ScriptsList(BaseModel): - txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)") - img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)") + txt2img: list | None = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)") + img2img: list | None = Field(default=None, title="Img2img", description="Titles of scripts (img2img)") class ScriptArg(BaseModel): - label: str = Field(default=None, title="Label", description="Name of the argument in UI") + label: str | None = Field(default=None, title="Label", description="Name of the argument in UI") value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument") minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI") maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI") @@ -307,9 +321,9 @@ class ScriptArg(BaseModel): class ScriptInfo(BaseModel): - name: str = Field(default=None, title="Name", description="Script name") - is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") - is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") + name: str | None = Field(default=None, title="Name", description="Script name") + is_alwayson: bool | None = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") + is_img2img: bool | None = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments") class ExtensionItem(BaseModel): diff --git a/modules/cache.py b/modules/cache.py index a9822a0eb..f4e5f702b 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -2,48 +2,55 @@ import os import os.path import threading -import time + +import diskcache +import tqdm from modules.paths import data_path, script_path cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json")) -cache_data = None +cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache")) +caches = {} cache_lock = threading.Lock() -dump_cache_after = None -dump_cache_thread = None - def dump_cache(): - """ - Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written. 
- """ + """old function for dumping cache to disk; does nothing since diskcache.""" - global dump_cache_after - global dump_cache_thread + pass - def thread_func(): - global dump_cache_after - global dump_cache_thread - while dump_cache_after is not None and time.time() < dump_cache_after: - time.sleep(1) +def make_cache(subsection: str) -> diskcache.Cache: + return diskcache.Cache( + os.path.join(cache_dir, subsection), + size_limit=2**32, # 4 GB, culling oldest first + disk_min_file_size=2**18, # keep up to 256KB in Sqlite + ) - with cache_lock: - cache_filename_tmp = cache_filename + "-" - with open(cache_filename_tmp, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4, ensure_ascii=False) - os.replace(cache_filename_tmp, cache_filename) +def convert_old_cached_data(): + try: + with open(cache_filename, "r", encoding="utf8") as file: + data = json.load(file) + except FileNotFoundError: + return + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json') + return - dump_cache_after = None - dump_cache_thread = None + total_count = sum(len(keyvalues) for keyvalues in data.values()) - with cache_lock: - dump_cache_after = time.time() + 5 - if dump_cache_thread is None: - dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func) - dump_cache_thread.start() + with tqdm.tqdm(total=total_count, desc="converting cache") as progress: + for subsection, keyvalues in data.items(): + cache_obj = caches.get(subsection) + if cache_obj is None: + cache_obj = make_cache(subsection) + caches[subsection] = cache_obj + + for key, value in keyvalues.items(): + cache_obj[key] = value + progress.update(1) def cache(subsection): @@ -54,28 +61,21 @@ def cache(subsection): subsection (str): The subsection identifier for the cache. Returns: - dict: The cache data for the specified subsection. + diskcache.Cache: The cache data for the specified subsection. 
""" - global cache_data - - if cache_data is None: + cache_obj = caches.get(subsection) + if not cache_obj: with cache_lock: - if cache_data is None: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except FileNotFoundError: - cache_data = {} - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} - - s = cache_data.get(subsection, {}) - cache_data[subsection] = s - - return s + if not os.path.exists(cache_dir) and os.path.isfile(cache_filename): + convert_old_cached_data() + + cache_obj = caches.get(subsection) + if not cache_obj: + cache_obj = make_cache(subsection) + caches[subsection] = cache_obj + + return cache_obj def cached_data_for_file(subsection, title, filename, func): diff --git a/modules/call_queue.py b/modules/call_queue.py index bcd7c5462..555c35312 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,8 +1,9 @@ +import os.path from functools import wraps import html import time -from modules import shared, progress, errors, devices, fifo_lock +from modules import shared, progress, errors, devices, fifo_lock, profiling queue_lock = fifo_lock.FIFOLock() @@ -46,6 +47,22 @@ def f(*args, **kwargs): def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + @wraps(func) + def f(*args, **kwargs): + try: + res = func(*args, **kwargs) + finally: + shared.state.skipped = False + shared.state.interrupted = False + shared.state.stopping_generation = False + shared.state.job_count = 0 + shared.state.job = "" + return res + + return wrap_gradio_call_no_job(f, extra_outputs, add_stats) + + +def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False): @wraps(func) def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats @@ -65,9 +82,6 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" errors.report(f"{message}\n{arg_str}", exc_info=True) - shared.state.job = "" - shared.state.job_count = 0 - if extra_outputs_array is None: extra_outputs_array = [None, ''] @@ -76,11 +90,6 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): devices.torch_gc() - shared.state.skipped = False - shared.state.interrupted = False - shared.state.stopping_generation = False - shared.state.job_count = 0 - if not add_stats: return tuple(res) @@ -100,8 +109,8 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): sys_pct = sys_peak/max(sys_total, 1) * 100 toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)" - toltip_r = "Reserved: total amout of video memory allocated by the Torch library " - toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity" + toltip_r = "Reserved: total amount of video memory allocated by the Torch library " + toltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity" text_a = f"A: {active_peak/1024:.2f} GB" text_r = f"R: {reserved_peak/1024:.2f} GB" @@ -111,9 +120,15 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): else: vram_html = '' + if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename): + profiling_html = f"

<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>" + else: + profiling_html = '' + # last item is always HTML - res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>" + res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>
" return tuple(res) return f + diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 6730f144d..fdc30cf17 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -22,6 +22,7 @@ parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None) parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") +parser.add_argument("--models-dir", type=normalized_filepath, default=None, help="base path where models are stored; overrides --data-dir") parser.add_argument("--config", type=normalized_filepath, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=normalized_filepath, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=normalized_filepath, default=None, help="Path to directory with stable diffusion checkpoints") @@ -31,7 +32,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") -parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") +parser.add_argument("--max-batch-count", type=int, default=16, help="does not do anything") parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") @@ -43,7 +44,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") -parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. 
Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) @@ -55,6 +56,7 @@ parser.add_argument("--esrgan-models-path", type=normalized_filepath, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) parser.add_argument("--bsrgan-models-path", type=normalized_filepath, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=normalized_filepath, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) +parser.add_argument("--dat-models-path", type=normalized_filepath, help="Path to directory with DAT model file(s).", default=os.path.join(models_path, 'DAT')) parser.add_argument("--clip-models-path", type=normalized_filepath, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") @@ -122,7 +124,10 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) -parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) +parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui") +parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system") +parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system') +parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file") # Arguments added by forge. 
parser.add_argument( diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 44b84618e..0b353353b 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -50,7 +50,7 @@ def restore(self, np_image, w: float | None = None): def restore_face(cropped_face_t): assert self.net is not None - return self.net(cropped_face_t, w=w, adain=True)[0] + return self.net(cropped_face_t, weight=w, adain=True)[0] return self.restore_with_helper(np_image, restore_face) diff --git a/modules/errors.py b/modules/errors.py index 48aa13a17..ecc2280d0 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -109,7 +109,7 @@ def check_versions(): expected_torch_version = "2.1.2" expected_xformers_version = "0.0.23.post1" - expected_gradio_version = "3.41.2" + expected_gradio_version = "4.39.0" if version.parse(torch.__version__) < version.parse(expected_torch_version): print_error_explanation(f""" diff --git a/modules/extensions.py b/modules/extensions.py index a47cdbe96..715a864c7 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,6 +1,7 @@ from __future__ import annotations import configparser +import dataclasses import os import threading import re @@ -10,6 +11,10 @@ from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 from modules_forge.config import always_disabled_extensions +extensions: list[Extension] = [] +extension_paths: dict[str, Extension] = {} +loaded_extensions: dict[str, Exception] = {} + os.makedirs(extensions_dir, exist_ok=True) @@ -23,6 +28,13 @@ def active(): return [x for x in extensions if x.enabled] +@dataclasses.dataclass +class CallbackOrderInfo: + name: str + before: list + after: list + + class ExtensionMetadata: filename = "metadata.ini" config: configparser.ConfigParser @@ -43,7 +55,7 @@ def __init__(self, path, canonical_name): self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = canonical_name.lower().strip() - self.requires = self.get_script_requirements("Requires", "Extension") + self.requires = None def get_script_requirements(self, field, section, extra_section=None): """reads a list of requirements from the config; field is the name of the field in the ini file, @@ -55,7 +67,15 @@ def get_script_requirements(self, field, section, extra_section=None): if extra_section: x = x + ', ' + self.config.get(extra_section, field, fallback='') - return self.parse_list(x.lower()) + listed_requirements = self.parse_list(x.lower()) + res = [] + + for requirement in listed_requirements: + loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions) + relevant_requirement = next(loaded_requirements, requirement) + res.append(relevant_requirement) + + return res def parse_list(self, text): """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" @@ -66,6 +86,22 @@ def parse_list(self, text): # both "," and " " are accepted as separator return [x for x in re.split(r"[,\s]+", text.strip()) if x] + def list_callback_order_instructions(self): + for section in self.config.sections(): + if not section.startswith("callbacks/"): + continue + + callback_name = section[10:] + + if not callback_name.startswith(self.canonical_name): + errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}") + continue + + before = self.parse_list(self.config.get(section, 'Before', fallback='')) + after = 
self.parse_list(self.config.get(section, 'After', fallback='')) + + yield CallbackOrderInfo(callback_name, before, after) + class Extension: lock = threading.Lock() @@ -156,14 +192,17 @@ def list_files(self, subdir, extension): def check_updates(self): repo = Repo(self.path) + branch_name = f'{repo.remote().name}/{self.branch}' for fetch in repo.remote().fetch(dry_run=True): + if self.branch and fetch.name != branch_name: + continue if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True self.status = "new commits" return try: - origin = repo.rev_parse('origin') + origin = repo.rev_parse(branch_name) if repo.head.commit != origin: self.can_update = True self.status = "behind HEAD" @@ -176,8 +215,10 @@ def check_updates(self): self.can_update = False self.status = "latest" - def fetch_and_reset_hard(self, commit='origin'): + def fetch_and_reset_hard(self, commit=None): repo = Repo(self.path) + if commit is None: + commit = f'{repo.remote().name}/{self.branch}' # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. repo.git.fetch(all=True) @@ -187,6 +228,8 @@ def fetch_and_reset_hard(self, commit='origin'): def list_extensions(): extensions.clear() + extension_paths.clear() + loaded_extensions.clear() if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") @@ -197,7 +240,6 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - loaded_extensions = {} # scan through extensions directory and load metadata for dirname in [extensions_builtin_dir, extensions_dir]: @@ -231,8 +273,12 @@ def list_extensions(): ) extensions.append(extension) + extension_paths[extension.path] = extension loaded_extensions[canonical_name] = extension + for extension in extensions: + extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension") + # check for requirements for extension in extensions: if not extension.enabled: @@ -249,4 +295,16 @@ def list_extensions(): continue -extensions: list[Extension] = [] +def find_extension(filename): + parentdir = os.path.dirname(os.path.realpath(filename)) + + while parentdir != filename: + extension = extension_paths.get(parentdir) + if extension is not None: + return extension + + filename = parentdir + parentdir = os.path.dirname(filename) + + return None + diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 04249dffd..ae8d42d9b 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -60,7 +60,7 @@ def activate(self, p, params_list): Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments separated by colon. - Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list - + Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list - in this case, all effects of this extra networks should be disabled. Can be called multiple times before deactivate() - each new call should override the previous call completely. 
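The requirement handling added to ExtensionMetadata above can be easier to follow in isolation: each comma- or space-separated entry in the Requires field may list alternatives joined by "|", and the first alternative that is already loaded is kept, with the raw entry used as a fallback when none of them are. Below is a minimal standalone sketch of that rule; the helper name and the example extension names are made up for illustration and are not part of the diff.

# Illustrative sketch only -- mirrors the "|" alternative resolution in get_script_requirements above.
import re

def resolve_requirements(requires_value: str, loaded: set[str]) -> list[str]:
    # split on commas/whitespace, the same way ExtensionMetadata.parse_list does
    listed = [x for x in re.split(r"[,\s]+", requires_value.strip().lower()) if x]
    resolved = []
    for requirement in listed:
        loaded_alternatives = (x for x in requirement.split("|") if x in loaded)
        resolved.append(next(loaded_alternatives, requirement))
    return resolved

# With only "lycoris" loaded, "lora|lycoris" resolves to "lycoris";
# "sd-webui-controlnet" is not loaded, so it is kept as-is for the later requirement check.
print(resolve_requirements("lora|lycoris, sd-webui-controlnet", {"lycoris"}))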
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 445b04092..01ef899e4 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -36,13 +36,11 @@ def load_net(self) -> torch.Module: ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - model = modelloader.load_spandrel_model( + return modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/gradio_extensions.py b/modules/gradio_extensions.py new file mode 100644 index 000000000..84414f6e3 --- /dev/null +++ b/modules/gradio_extensions.py @@ -0,0 +1,166 @@ +import inspect +import warnings +from functools import wraps + +import gradio as gr +import gradio.component_meta + + +from modules import scripts, ui_tempdir, patches + + +class GradioDeprecationWarning(DeprecationWarning): + pass + + +def add_classes_to_gradio_component(comp): + """ + this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others + """ + + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(getattr(comp, 'elem_classes', None) or [])] + + if getattr(comp, 'multiselect', False): + comp.elem_classes.append('multiselect') + + +def IOComponent_init(self, *args, **kwargs): + self.webui_tooltip = kwargs.pop('tooltip', None) + + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Block_get_config(self): + config = original_Block_get_config(self) + + webui_tooltip = getattr(self, 'webui_tooltip', None) + if webui_tooltip: + config["webui_tooltip"] = webui_tooltip + + config.pop('example_inputs', None) + + return config + + +def BlockContext_init(self, *args, **kwargs): + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_BlockContext_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Blocks_get_config_file(self, *args, **kwargs): + config = original_Blocks_get_config_file(self, *args, **kwargs) + + for comp_config in config["components"]: + if "example_inputs" in comp_config: + comp_config["example_inputs"] = {"serialized": []} + + return config + + +original_IOComponent_init = patches.patch(__name__, obj=gr.components.Component, field="__init__", replacement=IOComponent_init) +original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) +original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) +original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) + + 
+ui_tempdir.install_ui_tempdir_override() + + +def gradio_component_meta_create_or_modify_pyi(component_class, class_name, events): + if hasattr(component_class, 'webui_do_not_create_gradio_pyi_thank_you'): + return + + gradio_component_meta_create_or_modify_pyi_original(component_class, class_name, events) + + +# this prevents creation of .pyi files in webui dir +gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gradio.component_meta, 'create_or_modify_pyi', gradio_component_meta_create_or_modify_pyi) + +# this function is broken and does not seem to do anything useful +gradio.component_meta.updateable = lambda x: x + +def repair(grclass): + if not getattr(grclass, 'EVENTS', None): + return + + @wraps(grclass.__init__) + def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs): + if source: + kwargs["sources"] = [source] + + allowed_kwargs = inspect.signature(original).parameters + fixed_kwargs = {} + for k, v in kwargs.items(): + if k in allowed_kwargs: + fixed_kwargs[k] = v + else: + warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2) + + original(self, *args, **fixed_kwargs) + + self.webui_tooltip = tooltip + + for event in self.EVENTS: + replaced_event = getattr(self, str(event)) + + def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs): + if _js: + xkwargs['js'] = _js + + return replaced_event(*xargs, **xkwargs) + + setattr(self, str(event), fun) + + grclass.__init__ = __repaired_init__ + grclass.update = gr.update + + +for component in set(gr.components.__all__ + gr.layouts.__all__): + repair(getattr(gr, component, None)) + + +class Dependency(gr.events.Dependency): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def then(*xargs, _js=None, **xkwargs): + if _js: + xkwargs['js'] = _js + + return original_then(*xargs, **xkwargs) + + original_then = self.then + self.then = then + + +gr.events.Dependency = Dependency + +gr.Box = gr.Group + diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py deleted file mode 100644 index 7d88dc984..000000000 --- a/modules/gradio_extensons.py +++ /dev/null @@ -1,83 +0,0 @@ -import gradio as gr - -from modules import scripts, ui_tempdir, patches - - -def add_classes_to_gradio_component(comp): - """ - this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others - """ - - comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] - - if getattr(comp, 'multiselect', False): - comp.elem_classes.append('multiselect') - - -def IOComponent_init(self, *args, **kwargs): - self.webui_tooltip = kwargs.pop('tooltip', None) - - if scripts.scripts_current is not None: - scripts.scripts_current.before_component(self, **kwargs) - - scripts.script_callbacks.before_component_callback(self, **kwargs) - - res = original_IOComponent_init(self, *args, **kwargs) - - add_classes_to_gradio_component(self) - - scripts.script_callbacks.after_component_callback(self, **kwargs) - - if scripts.scripts_current is not None: - scripts.scripts_current.after_component(self, **kwargs) - - return res - - -def Block_get_config(self): - config = original_Block_get_config(self) - - webui_tooltip = getattr(self, 'webui_tooltip', None) - if webui_tooltip: - config["webui_tooltip"] = webui_tooltip - - config.pop('example_inputs', None) - - return config - - -def BlockContext_init(self, *args, **kwargs): - if scripts.scripts_current is not 
None: - scripts.scripts_current.before_component(self, **kwargs) - - scripts.script_callbacks.before_component_callback(self, **kwargs) - - res = original_BlockContext_init(self, *args, **kwargs) - - add_classes_to_gradio_component(self) - - scripts.script_callbacks.after_component_callback(self, **kwargs) - - if scripts.scripts_current is not None: - scripts.scripts_current.after_component(self, **kwargs) - - return res - - -def Blocks_get_config_file(self, *args, **kwargs): - config = original_Blocks_get_config_file(self, *args, **kwargs) - - for comp_config in config["components"]: - if "example_inputs" in comp_config: - comp_config["example_inputs"] = {"serialized": []} - - return config - - -original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init) -original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) -original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) -original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) - - -ui_tempdir.install_ui_tempdir_override() diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index be3e46484..17454665f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -11,7 +11,7 @@ from einops import rearrange, repeat from ldm.util import default from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors -from modules.textual_inversion import textual_inversion, logging +from modules.textual_inversion import textual_inversion, saving_settings from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ @@ -95,6 +95,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N zeros_(b) else: raise KeyError(f"Key {weight_init} is not defined as initialization!") + devices.torch_npu_set_device() self.to(devices.device) def fix_old_state_dict(self, state_dict): @@ -532,7 +533,7 @@ def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} ) - logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) + saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()}) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/images.py b/modules/images.py index b6f2358c3..031396ee8 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime - +import functools import pytz import io import math @@ -12,7 +12,9 @@ import numpy as np import piexif import piexif.helper -from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin +from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps +# pillow_avif needs to be imported somewhere in code for it to work +import pillow_avif # noqa: F401 import string import json import hashlib @@ -52,11 +54,14 @@ def image_grid(imgs, batch_size=1, rows=None): 
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows) script_callbacks.image_grid_callback(params) - w, h = imgs[0].size - grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black') + w, h = map(max, zip(*(img.size for img in imgs))) + grid_background_color = ImageColor.getcolor(opts.grid_background_color, 'RGBA') + grid = Image.new('RGBA', size=(params.cols * w, params.rows * h), color=grid_background_color) for i, img in enumerate(params.imgs): - grid.paste(img, box=(i % params.cols * w, i // params.cols * h)) + img_w, img_h = img.size + w_offset, h_offset = 0 if img_w == w else (w - img_w) // 2, 0 if img_h == h else (h - img_h) // 2 + grid.paste(img, box=(i % params.cols * w + w_offset, i // params.cols * h + h_offset)) return grid @@ -244,7 +249,7 @@ def draw_prompt_matrix(im, width, height, all_prompts, margin=0): return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin) -def resize_image(resize_mode, im, width, height, upscaler_name=None): +def resize_image(resize_mode, im, width, height, upscaler_name=None, force_RGBA=False): """ Resizes an image with the specified resize_mode, width, and height. @@ -262,7 +267,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): upscaler_name = upscaler_name or opts.upscaler_for_img2img def resize(im, w, h): - if upscaler_name is None or upscaler_name == "None" or im.mode == 'L': + if upscaler_name is None or upscaler_name == "None" or im.mode == 'L' or force_RGBA: return im.resize((w, h), resample=LANCZOS) scale = max(w / im.width, h / im.height) @@ -293,7 +298,7 @@ def resize(im, w, h): src_h = height if ratio <= src_ratio else im.height * width // im.width resized = resize(im, src_w, src_h) - res = Image.new("RGB", (width, height)) + res = Image.new("RGB" if not force_RGBA else "RGBA", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) else: @@ -304,7 +309,7 @@ def resize(im, w, h): src_h = height if ratio >= src_ratio else im.height * width // im.width resized = resize(im, src_w, src_h) - res = Image.new("RGB", (width, height)) + res = Image.new("RGB" if not force_RGBA else "RGBA", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) if ratio < src_ratio: @@ -321,13 +326,16 @@ def resize(im, w, h): return res -invalid_filename_chars = '#<>:"/\\|?*\n\r\t' +if not shared.cmd_opts.unix_filenames_sanitization: + invalid_filename_chars = '#<>:"/\\|?*\n\r\t' +else: + invalid_filename_chars = '/' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' 
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)") re_pattern_arg = re.compile(r"(.*)<([^>]*)>$") -max_filename_part_length = 128 +max_filename_part_length = shared.cmd_opts.filenames_max_length NOTHING_AND_SKIP_PREVIOUS_TEXT = object() @@ -344,8 +352,35 @@ def sanitize_filename_part(text, replace_spaces=True): return text +@functools.cache +def get_scheduler_str(sampler_name, scheduler_name): + """Returns {Scheduler} if the scheduler is applicable to the sampler""" + if scheduler_name == 'Automatic': + config = sd_samplers.find_sampler_config(sampler_name) + scheduler_name = config.options.get('scheduler', 'Automatic') + return scheduler_name.capitalize() + + +@functools.cache +def get_sampler_scheduler_str(sampler_name, scheduler_name): + """Returns the '{Sampler} {Scheduler}' if the scheduler is applicable to the sampler""" + return f'{sampler_name} {get_scheduler_str(sampler_name, scheduler_name)}' + + +def get_sampler_scheduler(p, sampler): + """Returns '{Sampler} {Scheduler}' / '{Scheduler}' / 'NOTHING_AND_SKIP_PREVIOUS_TEXT'""" + if hasattr(p, 'scheduler') and hasattr(p, 'sampler_name'): + if sampler: + sampler_scheduler = get_sampler_scheduler_str(p.sampler_name, p.scheduler) + else: + sampler_scheduler = get_scheduler_str(p.sampler_name, p.scheduler) + return sanitize_filename_part(sampler_scheduler, replace_spaces=False) + return NOTHING_AND_SKIP_PREVIOUS_TEXT + + class FilenameGenerator: replacements = { + 'basename': lambda self: self.basename or 'img', 'seed': lambda self: self.seed if self.seed is not None else '', 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0], 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1], @@ -355,6 +390,8 @@ class FilenameGenerator: 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), + 'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True), + 'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), @@ -380,12 +417,13 @@ class FilenameGenerator: } default_time_format = '%Y%m%d%H%M%S' - def __init__(self, p, seed, prompt, image, zip=False): + def __init__(self, p, seed, prompt, image, zip=False, basename=""): self.p = p self.seed = seed self.prompt = prompt self.image = image self.zip = zip + self.basename = basename def get_vae_filename(self): """Get the name of the VAE file.""" @@ -566,6 +604,17 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p }) piexif.insert(exif_bytes, filename) + elif extension.lower() == '.avif': + if opts.enable_pnginfo and geninfo is not None: + exif_bytes = piexif.dump({ + "Exif": { + piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode") + }, + }) + else: + exif_bytes = None + + image.save(filename,format=image_format, quality=opts.jpeg_quality, exif=exif_bytes) elif extension.lower() == ".gif": image.save(filename, 
format=image_format, comment=geninfo) else: @@ -605,12 +654,12 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn (`str` or None): If a text file is saved for this image, this will be its full path. Otherwise None. """ - namegen = FilenameGenerator(p, seed, prompt, image) + namegen = FilenameGenerator(p, seed, prompt, image, basename=basename) # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp": print('Image dimensions too large; saving as PNG') - extension = ".png" + extension = "png" if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) @@ -744,10 +793,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: exif_comment = exif_comment.decode('utf8', errors="ignore") if exif_comment: - items['exif comment'] = exif_comment geninfo = exif_comment elif "comment" in items: # for gif - geninfo = items["comment"].decode('utf8', errors="ignore") + if isinstance(items["comment"], bytes): + geninfo = items["comment"].decode('utf8', errors="ignore") + else: + geninfo = items["comment"] for field in IGNORED_INFO_KEYS: items.pop(field, None) @@ -770,7 +821,7 @@ def image_data(data): import gradio as gr try: - image = Image.open(io.BytesIO(data)) + image = read(io.BytesIO(data)) textinfo, _ = read_info_from_image(image) return textinfo, None except Exception: @@ -797,3 +848,30 @@ def flatten(img, bgcolor): return img.convert('RGB') + +def read(fp, **kwargs): + image = Image.open(fp, **kwargs) + image = fix_image(image) + + return image + + +def fix_image(image: Image.Image): + if image is None: + return None + + try: + image = ImageOps.exif_transpose(image) + image = fix_png_transparency(image) + except Exception: + pass + + return image + + +def fix_png_transparency(image: Image.Image): + if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes): + return image + + image = image.convert("RGBA") + return image diff --git a/modules/img2img.py b/modules/img2img.py index 6e9729a4b..5a1360105 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -2,11 +2,10 @@ from contextlib import closing from pathlib import Path -import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr -from modules import images as imgutil +from modules import images from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state @@ -18,11 +17,14 @@ from modules_forge import main_thread -def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): +def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): output_dir = output_dir.strip() processing.fix_seed(p) - images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff"))) + if isinstance(input, str): + batch_images = list(shared.walk_files(input, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", 
".tif", ".tiff"))) + else: + batch_images = [os.path.abspath(x.name) for x in input] is_inpaint_batch = False if inpaint_mask_dir: @@ -32,9 +34,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if is_inpaint_batch: print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") - print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") + print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.") - state.job_count = len(images) * p.n_iter + state.job_count = len(batch_images) * p.n_iter # extract "default" params to use in case getting png info fails prompt = p.prompt @@ -47,8 +49,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None)) batch_results = None discard_further_results = False - for i, image in enumerate(images): - state.job = f"{i+1} out of {len(images)}" + for i, image in enumerate(batch_images): + state.job = f"{i+1} out of {len(batch_images)}" if state.skipped: state.skipped = False @@ -56,7 +58,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal break try: - img = Image.open(image) + img = images.read(image) except UnidentifiedImageError as e: print(e) continue @@ -87,7 +89,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal # otherwise user has many masks with the same name but different extensions mask_image_path = masks_found[0] - mask_image = Image.open(mask_image_path) + mask_image = images.read(mask_image_path) p.image_mask = mask_image if use_png_info: @@ -95,8 +97,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal info_img = img if png_info_dir: info_img_path = os.path.join(png_info_dir, os.path.basename(image)) - info_img = Image.open(info_img_path) - geninfo, _ = imgutil.read_info_from_image(info_img) + info_img = images.read(info_img_path) + geninfo, _ = images.read_info_from_image(info_img) parsed_parameters = parse_generation_parameters(geninfo) parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} except Exception: @@ -147,38 +149,40 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal return batch_results -def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img_function(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, sketch_fg, init_img_with_mask, init_img_with_mask_fg, inpaint_color_sketch, inpaint_color_sketch_fg, init_img_inpaint, init_mask_inpaint, mask_blur: int, 
mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args): + override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 + height, width = int(height), int(width) + if mode == 0: # img2img image = init_img mask = None elif mode == 1: # img2img sketch - image = sketch mask = None + image = Image.alpha_composite(sketch, sketch_fg) elif mode == 2: # inpaint - image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - mask = processing.create_binary_mask(mask) + image = init_img_with_mask + mask = init_img_with_mask_fg.getchannel('A').convert('L') + mask = Image.merge('RGBA', (mask, mask, mask, Image.new('L', mask.size, 255))) elif mode == 3: # inpaint sketch - image = inpaint_color_sketch - orig = inpaint_color_sketch_orig or inpaint_color_sketch - pred = np.any(np.array(image) != np.array(orig), axis=-1) - mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") - mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) - blur = ImageFilter.GaussianBlur(mask_blur) - image = Image.composite(image.filter(blur), orig, mask.filter(blur)) + image = Image.alpha_composite(inpaint_color_sketch, inpaint_color_sketch_fg) + mask = inpaint_color_sketch_fg.getchannel('A').convert('L') + short_side = min(mask.size) + dilation_size = int(0.015 * short_side) * 2 + 1 + mask = mask.filter(ImageFilter.MaxFilter(dilation_size)) + mask = Image.merge('RGBA', (mask, mask, mask, Image.new('L', mask.size, 255))) elif mode == 4: # inpaint upload mask image = init_img_inpaint mask = init_mask_inpaint - else: - image = None - mask = None - # Use the EXIF orientation of photos taken by smartphones. 
- if image is not None: - image = ImageOps.exif_transpose(image) + if mask and isinstance(mask, Image.Image): + mask = mask.point(lambda v: 255 if v > 128 else 0) + + image = images.fix_image(image) + mask = images.fix_image(mask) if selected_scale_tab == 1 and not is_batch: assert image, "Can't scale by because no image is selected" @@ -195,10 +199,8 @@ def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt=prompt, negative_prompt=negative_prompt, styles=prompt_styles, - sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, - steps=steps, cfg_scale=cfg_scale, width=width, height=height, @@ -225,8 +227,15 @@ def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, with closing(p): if is_batch: - assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) + if img2img_batch_source_type == "upload": + assert isinstance(img2img_batch_upload, list) and img2img_batch_upload + output_dir = "" + inpaint_mask_dir = "" + png_info_dir = img2img_batch_png_info_dir if not shared.cmd_opts.hide_ui_dir_config else "" + processed = process_batch(p, img2img_batch_upload, output_dir, inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=png_info_dir) + else: # "from dir" + assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) if processed is None: processed = Processed(p, [], p.seed, "") @@ -247,5 +256,5 @@ def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): - return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, 
inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args) +def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, sketch_fg, init_img_with_mask, init_img_with_mask_fg, inpaint_color_sketch, inpaint_color_sketch_fg, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args): + return main_thread.run_and_wait_result(img2img_function, id_task, request, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, sketch_fg, init_img_with_mask, init_img_with_mask_fg, inpaint_color_sketch, inpaint_color_sketch_fg, init_img_inpaint, init_mask_inpaint, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_source_type, img2img_batch_upload, *args) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index d21d33330..0f488b0de 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -8,7 +8,7 @@ import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser, errors from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name @@ -74,29 +74,38 @@ def image_from_url_text(filedata): if filedata is None: return None - if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): + if isinstance(filedata, list): + if len(filedata) == 0: + return None + filedata = filedata[0] + if isinstance(filedata, dict) and filedata.get("is_file", False): + filedata = filedata + + filename = None if type(filedata) == dict and filedata.get("is_file", False): filename = filedata["name"] + + elif isinstance(filedata, tuple) and len(filedata) == 2: # gradio 4.16 sends images from gallery as a list of tuples + return filedata[0] + + if filename: is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories' filename = filename.rsplit('?', 1)[0] - return Image.open(filename) + return images.read(filename) - if 
type(filedata) == list: - if len(filedata) == 0: - return None + if isinstance(filedata, str): + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] - filedata = filedata[0] - - if filedata.startswith("data:image/png;base64,"): - filedata = filedata[len("data:image/png;base64,"):] + filedata = base64.decodebytes(filedata.encode('utf-8')) + image = images.read(io.BytesIO(filedata)) + return image - filedata = base64.decodebytes(filedata.encode('utf-8')) - image = Image.open(io.BytesIO(filedata)) - return image + return None def add_paste_fields(tabname, init_img, fields, override_settings_component=None): @@ -138,8 +147,6 @@ def register_paste_params_button(binding: ParamBinding): def connect_paste_params_buttons(): for binding in registered_param_bindings: - if binding.tabname not in paste_fields: - continue destination_image_component = paste_fields[binding.tabname]["init_img"] fields = paste_fields[binding.tabname]["fields"] override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"] @@ -148,18 +155,19 @@ def connect_paste_params_buttons(): destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) if binding.source_image_component and destination_image_component: + need_send_dementions = destination_width_component and binding.tabname != 'inpaint' if isinstance(binding.source_image_component, gr.Gallery): - func = send_image_and_dimensions if destination_width_component else image_from_url_text + func = send_image_and_dimensions if need_send_dementions else image_from_url_text jsfunc = "extract_image_from_gallery" else: - func = send_image_and_dimensions if destination_width_component else lambda x: x + func = send_image_and_dimensions if need_send_dementions else lambda x: x jsfunc = None binding.paste_button.click( fn=func, _js=jsfunc, inputs=[binding.source_image_component], - outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], + outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dementions else [destination_image_component], show_progress=False, ) @@ -187,6 +195,8 @@ def connect_paste_params_buttons(): def send_image_and_dimensions(x): if isinstance(x, Image.Image): img = x + elif isinstance(x, list) and isinstance(x[0], tuple): + img = x[0][0] else: img = image_from_url_text(x) @@ -267,17 +277,6 @@ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): else: prompt += ("" if prompt == "" else "\n") + line - if shared.opts.infotext_styles != "Ignore": - found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) - - if shared.opts.infotext_styles == "Apply": - res["Styles array"] = found_styles - elif shared.opts.infotext_styles == "Apply if any" and found_styles: - res["Styles array"] = found_styles - - res["Prompt"] = prompt - res["Negative prompt"] = negative_prompt - for k, v in re_param.findall(lastline): try: if v[0] == '"' and v[-1] == '"': @@ -292,6 +291,26 @@ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): except Exception: print(f"Error parsing \"{k}: {v}\"") + # Extract styles from prompt + if shared.opts.infotext_styles != "Ignore": + found_styles, prompt_no_styles, negative_prompt_no_styles = 
shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) + + same_hr_styles = True + if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True): + hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt) + hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt) + if same_hr_styles := found_styles == hr_found_styles: + res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles + res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles + + if same_hr_styles: + prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles + if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply": + res['Styles array'] = found_styles + + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt + # Missing CLIP skip means it was set to 1 (the default) if "Clip skip" not in res: res["Clip skip"] = "1" @@ -307,6 +326,9 @@ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires schedule type" not in res: + res["Hires schedule type"] = "Use same scheduler" + if "Hires checkpoint" not in res: res["Hires checkpoint"] = "Use same checkpoint" @@ -358,9 +380,15 @@ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": res["Cache FP16 weight for LoRA"] = False - if "Emphasis" not in res: + prompt_attention = prompt_parser.parse_prompt_attention(prompt) + prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt) + prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK']) + if "Emphasis" not in res and prompt_uses_emphasis: res["Emphasis"] = "Original" + if "Refiner switch by sampling steps" not in res: + res["Refiner switch by sampling steps"] = False + infotext_versions.backcompat(res) for key in skip_fields: @@ -396,6 +424,9 @@ def create_override_settings_dict(text_pairs): res = {} + if not text_pairs: + return res + params = {} for pair in text_pairs: k, v = pair.split(":", maxsplit=1) @@ -458,7 +489,7 @@ def get_override_settings(params, *, skip_fields=None): def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): - if not prompt and not shared.cmd_opts.hide_ui_dir_config: + if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history: filename = os.path.join(data_path, "params.txt") try: with open(filename, "r", encoding="utf8") as file: @@ -472,7 +503,11 @@ def paste_func(prompt): for output, key in paste_fields: if callable(key): - v = key(params) + try: + v = key(params) + except Exception: + errors.report(f"Error executing {key}", exc_info=True) + v = None else: v = params.get(key, None) diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index 23b45c3f9..cea676cda 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -5,6 +5,8 @@ v160 = version.parse("1.6.0") v170_tsnr = version.parse("v1.7.0-225") +v180 = 
version.parse("1.8.0") +v180_hr_styles = version.parse("1.8.0-139") def parse_version(text): @@ -40,3 +42,5 @@ def backcompat(d): if ver < v170_tsnr: d["Downcast alphas_cumprod"] = True + if ver < v180 and d.get('Refiner'): + d["Refiner switch by sampling steps"] = True diff --git a/modules/initialize.py b/modules/initialize.py index 180e1f8e6..ec4d58a43 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -50,7 +50,7 @@ def imports(): shared_init.initialize() startup_timer.record("initialize shared") - from modules import processing, gradio_extensons, ui # noqa: F401 + from modules import processing, gradio_extensions, ui # noqa: F401 startup_timer.record("other imports") @@ -65,6 +65,7 @@ def check_versions(): def initialize(): from modules import initialize_util initialize_util.fix_torch_version() + initialize_util.fix_pytorch_lightning() initialize_util.fix_asyncio_event_loop_policy() initialize_util.validate_tls_options() initialize_util.configure_sigint_handler() @@ -123,7 +124,7 @@ def initialize_rest(*, reload_script_modules=False): with startup_timer.subcategory("load scripts"): scripts.load_scripts() - if reload_script_modules: + if reload_script_modules and shared.opts.enable_reloading_ui_scripts: for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: importlib.reload(module) startup_timer.record("reload script modules") diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 7801d9329..693b083c5 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -4,6 +4,8 @@ import sys import re +import starlette + from modules.timer import startup_timer @@ -24,6 +26,13 @@ def fix_torch_version(): torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) +def fix_pytorch_lightning(): + # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache + if 'pytorch_lightning.utilities.distributed' not in sys.modules: + import pytorch_lightning + # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero + print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero") + sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero def fix_asyncio_event_loop_policy(): """ @@ -186,8 +195,7 @@ def configure_opts_onchange(): def setup_middleware(app): from starlette.middleware.gzip import GZipMiddleware - app.middleware_stack = None # reset current middleware to allow modifying user provided list - app.add_middleware(GZipMiddleware, minimum_size=1000) + app.user_middleware.insert(0, starlette.middleware.Middleware(GZipMiddleware, minimum_size=1000)) configure_cors_middleware(app) app.build_middleware_stack() # rebuild middleware stack on-the-fly @@ -205,5 +213,6 @@ def configure_cors_middleware(app): cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') if cmd_opts.cors_allow_origins_regex: cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex - app.add_middleware(CORSMiddleware, **cors_options) + + app.user_middleware.insert(0, starlette.middleware.Middleware(CORSMiddleware, **cors_options)) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 596d14b66..4aac9e97c 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -9,6 +9,7 @@ import importlib.metadata import platform import json +import shlex from functools import lru_cache from typing 
import NamedTuple from pathlib import Path @@ -60,7 +61,7 @@ def check_python_version(): You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/ -{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""} +{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""} Use --skip-python-version-check to suppress this warning. """) @@ -81,7 +82,7 @@ def git_tag_a1111(): except Exception: try: - changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md") + changelog_md = os.path.join(script_path, "CHANGELOG.md") with open(changelog_md, "r", encoding="utf-8") as file: line = next((line.strip() for line in file if line.strip()), "") line = line.replace("## ", "") @@ -240,7 +241,7 @@ def run_extension_installer(extension_dir): try: env = os.environ.copy() - env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}" stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip() if stdout: @@ -490,7 +491,6 @@ def prepare_environment(): exit(0) - def configure_for_tests(): if "--api" not in sys.argv: sys.argv.append("--api") @@ -537,7 +537,7 @@ class ModelRef(NamedTuple): def start(): - print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") + print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: webui.api_only() diff --git a/modules/lowvram.py b/modules/lowvram.py index 908b5962d..b6dcaf527 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,9 +1,12 @@ +from collections import namedtuple + import torch from modules import devices, shared module_in_gpu = None cpu = torch.device("cpu") +ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None']) def send_everything_to_cpu(): return diff --git a/modules/mac_specific.py b/modules/mac_specific.py index d96d86d79..039689f32 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -12,7 +12,7 @@ # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. -# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability, # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 def check_for_mps() -> bool: if version.parse(torch.__version__) <= version.parse("2.0.1"): diff --git a/modules/masking.py b/modules/masking.py index 29a394527..2fc830319 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -1,17 +1,39 @@ from PIL import Image, ImageFilter, ImageOps -def get_crop_region(mask, pad=0): - """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. 
- For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)""" - mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask) - box = mask_img.getbbox() - if box: +def get_crop_region_v2(mask, pad=0): + """ + Finds a rectangular region that contains all masked areas in a mask. + Returns None if the mask is completely black (all 0) + + Parameters: + mask: PIL.Image.Image L mode or numpy array + pad: int, number of pixels by which the region will be extended on all sides + Returns: (x1, y1, x2, y2) | None + + Introduced post 1.9.0 + """ + mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask) + if box := mask.getbbox(): x1, y1, x2, y2 = box - return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1]) + return (max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])) if pad else box + + +def get_crop_region(mask, pad=0): + """ + Same as get_crop_region_v2, but handles a completely black mask (all 0) differently: + when the mask is all black it still returns coordinates, but they may be invalid, i.e. x2 > x1 or y2 > y1 + Note: the coordinates can become "valid" again if the pad size is sufficiently large: + (mask_size.x-pad, mask_size.y-pad, pad, pad) + + Extension developers should use get_crop_region_v2 instead, unless compatibility is a concern. + """ + mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask) + if box := get_crop_region_v2(mask, pad): + return box + x1, y1 = mask.size + x2 = y2 = 0 + return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1]) def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): diff --git a/modules/modelloader.py b/modules/modelloader.py index e100bb246..36e7415af 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -23,6 +23,7 @@ def load_file_from_url( model_dir: str, progress: bool = True, file_name: str | None = None, + hash_prefix: str | None = None, ) -> str: """Download a file from `url` into `model_dir`, using the file present if possible. @@ -36,11 +37,11 @@ def load_file_from_url( if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') from torch.hub import download_url_to_file - download_url_to_file(url, cached_file, progress=progress) + download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix) return cached_file -def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -49,6 +50,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None @param model_path: The location to store/find models in. @param command_path: A command-line argument to search for models in first. 
@param ext_filter: An optional list of filename extensions to filter by + @param hash_prefix: the expected sha256 of the model_url @return: A list of paths containing the desired model(s) """ output = [] @@ -78,7 +80,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name, hash_prefix=hash_prefix)) else: output.append(model_url) @@ -110,7 +112,7 @@ def load_upscalers(): except Exception: pass - datas = [] + data = [] commandline_options = vars(shared.cmd_opts) # some of upscaler classes will not go away after reloading their modules, and we'll end @@ -129,14 +131,35 @@ def load_upscalers(): scaler = cls(commandline_model_path) scaler.user_path = commandline_model_path scaler.model_download_path = commandline_model_path or scaler.model_path - datas += scaler.scalers + data += scaler.scalers shared.sd_upscalers = sorted( - datas, + data, # Special case for UpscalerNone keeps it at the beginning of the list. key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) +# None: not loaded, False: failed to load, True: loaded +_spandrel_extra_init_state = None + + +def _init_spandrel_extra_archs() -> None: + """ + Try to initialize `spandrel_extra_archs` (exactly once). + """ + global _spandrel_extra_init_state + if _spandrel_extra_init_state is not None: + return + + try: + import spandrel + import spandrel_extra_arches + spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY) + _spandrel_extra_init_state = True + except Exception: + logger.warning("Failed to load spandrel_extra_arches", exc_info=True) + _spandrel_extra_init_state = False + def load_spandrel_model( path: str | os.PathLike, @@ -146,11 +169,16 @@ def load_spandrel_model( dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: + global _spandrel_extra_init_state + import spandrel + _init_spandrel_extra_archs() + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) - if expected_architecture and model_descriptor.architecture != expected_architecture: + arch = model_descriptor.architecture + if expected_architecture and arch.name != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})", ) half = False if prefer_half: @@ -164,6 +192,6 @@ def load_spandrel_model( model_descriptor.model.eval() logger.debug( "Loaded %s from %s (device=%s, half=%s, dtype=%s)", - model_descriptor, path, device, half, dtype, + arch, path, device, half, dtype, ) return model_descriptor diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 6db340da4..7b51c83c5 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -341,7 +341,7 @@ def p_losses(self, x_start, t, noise=None): elif self.parameterization == "x0": target = x_start else: - raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 
2, 3]) @@ -901,7 +901,7 @@ def forward(self, x, c, *args, **kwargs): def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict + # hybrid case, cond is expected to be a dict pass else: if not isinstance(cond, list): @@ -937,7 +937,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] elif self.cond_stage_key == 'coordinates_bbox': - assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size' # assuming padding of unfold is always 0 and its dilation is always 1 n_patches_per_row = int((w - ks[0]) / stride[0] + 1) @@ -947,7 +947,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): num_downs = self.first_stage_model.encoder.num_resolutions - 1 rescale_latent = 2 ** (num_downs) - # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # get top left positions of patches as conforming for the bbox tokenizer, therefore we # need to rescale the tl patch coordinates to be in between (0,1) tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 4a3651513..3333bc808 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -323,7 +323,7 @@ def cond_grad_fn(x, t_input, condition): def model_fn(x, t_continuous, condition, unconditional_condition): """ - The noise predicition model function that is used for DPM-Solver. + The noise prediction model function that is used for DPM-Solver. 
""" if t_continuous.reshape((-1,)).shape[0] == 1: t_continuous = t_continuous.expand((x.shape[0])) @@ -445,7 +445,7 @@ def data_prediction_fn(self, x, t): s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s - return x0.to(x) + return x0 def model_fn(self, x, t): """ diff --git a/modules/models/sd3/mmdit.py b/modules/models/sd3/mmdit.py new file mode 100644 index 000000000..8ddf49a4e --- /dev/null +++ b/modules/models/sd3/mmdit.py @@ -0,0 +1,622 @@ +### This file contains impls for MM-DiT, the core model component of SD3 + +import math +from typing import Dict, Optional +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat +from modules.models.sd3.other_impls import attention, Mlp + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding""" + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + flatten: bool = True, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.patch_size = (patch_size, patch_size) + if img_size is not None: + self.img_size = (img_size, img_size) + self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + return x + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb 
= np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations.""" + + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + def forward(self, t, dtype, **kwargs): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class VectorEmbedder(nn.Module): + """Embeds a flat vector of dimension input_dim""" + + def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +################################################################################# +# Core DiT Model # +################################################################################# + + +class QkvLinear(torch.nn.Linear): + pass + +def split_qkv(qkv, head_dim): + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0) + return qkv[0], qkv[1], qkv[2] + +def optimized_attention(qkv, num_heads): + return attention(qkv[0], qkv[1], qkv[2], num_heads) + +class SelfAttention(nn.Module): + ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug") + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_mode: str = "xformers", + pre_only: 
bool = False, + qk_norm: Optional[str] = None, + rmsnorm: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + if not pre_only: + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) + assert attn_mode in self.ATTENTION_MODES + self.attn_mode = attn_mode + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor): + B, L, C = x.shape + qkv = self.qkv(x) + q, k, v = split_qkv(qkv, self.head_dim) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + (q, k, v) = self.pre_attention(x) + x = attention(q, k, v, self.num_heads) + x = self.post_attention(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__( + self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The normalized tensor. + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = self._norm(x) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float] = None, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. 
+ w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class DismantledBlock(nn.Module): + """A DiT block with gated adaptive layer norm (adaLN) conditioning.""" + + ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug") + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + dtype=None, + device=None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device) + else: + self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor): + assert x is not None, "pre_attention called with None input" + if not self.pre_only: + if not self.scale_mod_only: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + else: + shift_msa = None + shift_mlp = None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), 
shift_mlp, scale_mlp)) + return x + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + (q, k, v), intermediates = self.pre_attention(x, c) + attn = attention(q, k, v, self.attn.num_heads) + return self.post_attention(attn, *intermediates) + + +def block_mixing(context, x, context_block, x_block, c): + assert context is not None, "block_mixing called with None context" + context_qkv, context_intermediates = context_block.pre_attention(context, c) + + x_qkv, x_intermediates = x_block.pre_attention(x, c) + + o = [] + for t in range(3): + o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1)) + q, k, v = tuple(o) + + attn = attention(q, k, v, x_block.attn.num_heads) + context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :]) + + if not context_block.pre_only: + context = context_block.post_attention(context_attn, *context_intermediates) + else: + context = None + x = x_block.post_attention(x_attn, *x_intermediates) + return context, x + + +class JointBlock(nn.Module): + """just a small wrapper to serve as a fsdp unit""" + + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + qk_norm = kwargs.pop("qk_norm", None) + self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs) + self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs) + + def forward(self, *args, **kwargs): + return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs) + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = ( + nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + if (total_out_channels is None) + else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device) + ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MMDiT(nn.Module): + """Diffusion model with a Transformer backbone.""" + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches = None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + dtype = None, + device = None, + ): + super().__init__() + self.dtype = dtype + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels 
+ self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + + # apply magic --> this defines a head_size of 64 + hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device) + self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device) + + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device) + + self.context_embedder = nn.Identity() + if context_embedder_config is not None: + if context_embedder_config["target"] == "torch.nn.Linear": + self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device) + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device), + ) + else: + self.pos_embed = None + + self.joint_blocks = nn.ModuleList( + [ + JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device) + + def cropped_pos_embed(self, hw): + assert self.pos_embed_max_size is not None + p = self.x_embedder.patch_size[0] + h, w = hw + # patched size + h = h // p + w = w // p + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = rearrange( + self.pos_embed, + "1 (h w) c -> 1 h w c", + h=self.pos_embed_max_size, + w=self.pos_embed_max_size, + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c") + return spatial_pos_embed + + def unpatchify(self, x, hw=None): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + if hw is None: + h = w = int(x.shape[1] ** 0.5) + else: + h, w = hw + h = h // p + w = w // p + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.register_length > 0: + context = torch.cat((repeat(self.register, "1 ... 
-> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1) + + # context is B, L', D + # x is B, L, D + for block in self.joint_blocks: + context, x = block(context, x, c=c_mod) + + x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) + return x + + def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + hw = x.shape[-2:] + x = self.x_embedder(x) + self.cropped_pos_embed(hw) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + context = self.context_embedder(context) + + x = self.forward_core_with_concat(x, c, context) + + x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) + return x diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py new file mode 100644 index 000000000..78c1dc687 --- /dev/null +++ b/modules/models/sd3/other_impls.py @@ -0,0 +1,510 @@ +### This file contains impls for underlying related models (CLIP, T5, etc) + +import torch +import math +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast + +from modules import sd_hijack + + +################################################################################################# +### Core/Utility +################################################################################################# + + +class AutocastLinear(nn.Linear): + """Same as usual linear layer, but casts its weights to whatever the parameter type is. + + This is different from torch.autocast in a way that float16 layer processing float32 input + will return float16 with autocast on, and float32 with this. T5 seems to be fucked + if you do it in full float16 (returning almost all zeros in the final output). 
+ """ + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)] + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.act = act_layer + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +################################################################################################# +### CLIP +################################################################################################# + + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, layer in enumerate(self.layers): + x = layer(x, mask) + if i == 
intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"): + super().__init__() + self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l')) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class SDTokenizer: + def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None): + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer('')["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + + def tokenize_with_weights(self, text:str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other 
features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(' ') + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text:str): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text) + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + return out + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = [a[0] for a in token_weight_pairs[0]] + out, pooled = self([tokens]) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = ["last", "pooled", "hidden"] + def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel, + special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407} + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device) + 
outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer="hidden" + layer_idx=-2 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + def __init__(self): + super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + 
super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + 
num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + else: + mask = None + + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None) + + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)]) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes + past_bias = None + for i, layer in enumerate(self.block): + x, past_bias = layer(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py new file mode 100644 index 
000000000..66f59e298 --- /dev/null +++ b/modules/models/sd3/sd3_cond.py @@ -0,0 +1,222 @@ +import os +import safetensors +import torch +import typing + +from transformers import CLIPTokenizer, T5TokenizerFast + +from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser +from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer + + +class SafetensorsMapping(typing.Mapping): + def __init__(self, file): + self.file = file + + def __len__(self): + return len(self.file.keys()) + + def __iter__(self): + for key in self.file.keys(): + yield key + + def __getitem__(self, key): + return self.file.get_tensor(key) + + +CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors" +CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, +} + +CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors" +CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + "textual_inversion_key": "clip_g", +} + +T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors" +T5_CONFIG = { + "d_ff": 10240, + "d_model": 4096, + "num_heads": 64, + "num_layers": 24, + "vocab_size": 32128, +} + + +class Sd3ClipLG(sd_hijack_clip.TextConditionalModel): + def __init__(self, clip_l, clip_g): + super().__init__() + + self.clip_l = clip_l + self.clip_g = clip_g + + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + empty = self.tokenizer('')["input_ids"] + self.id_start = empty[0] + self.id_end = empty[1] + self.id_pad = empty[1] + + self.return_pooled = True + + def tokenize(self, texts): + return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + def encode_with_transformers(self, tokens): + tokens_g = tokens.clone() + + for batch_pos in range(tokens_g.shape[0]): + index = tokens_g[batch_pos].cpu().tolist().index(self.id_end) + tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0 + + l_out, l_pooled = self.clip_l(tokens) + g_out, g_pooled = self.clip_g(tokens_g) + + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + + vector_out = torch.cat((l_pooled, g_pooled), dim=-1) + + lg_out.pooled = vector_out + return lg_out + + def encode_embedding_init_text(self, init_text, nvpt): + return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX + + +class Sd3T5(torch.nn.Module): + def __init__(self, t5xxl): + super().__init__() + + self.t5xxl = t5xxl + self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl") + + empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"] + self.id_end = empty[0] + self.id_pad = empty[1] + + def tokenize(self, texts): + return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + def tokenize_line(self, line, *, target_token_count=None): + if shared.opts.emphasis != "None": + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.tokenize([text for text, _ in parsed]) + + tokens = [] + multipliers = [] + + for text_tokens, (text, weight) in zip(tokenized, parsed): + if text == 'BREAK' and weight == -1: + continue + + tokens += text_tokens + 
multipliers += [weight] * len(text_tokens) + + tokens += [self.id_end] + multipliers += [1.0] + + if target_token_count is not None: + if len(tokens) < target_token_count: + tokens += [self.id_pad] * (target_token_count - len(tokens)) + multipliers += [1.0] * (target_token_count - len(tokens)) + else: + tokens = tokens[0:target_token_count] + multipliers = multipliers[0:target_token_count] + + return tokens, multipliers + + def forward(self, texts, *, token_count): + if not self.t5xxl or not shared.opts.sd3_enable_t5: + return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype) + + tokens_batch = [] + + for text in texts: + tokens, multipliers = self.tokenize_line(text, target_token_count=token_count) + tokens_batch.append(tokens) + + t5_out, t5_pooled = self.t5xxl(tokens_batch) + + return t5_out + + def encode_embedding_init_text(self, init_text, nvpt): + return torch.zeros((nvpt, 4096), device=devices.device) # XXX + + +class SD3Cond(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.tokenizer = SD3Tokenizer() + + with torch.no_grad(): + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + + if shared.opts.sd3_enable_t5: + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + else: + self.t5xxl = None + + self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g) + self.model_t5 = Sd3T5(self.t5xxl) + + def forward(self, prompts: list[str]): + with devices.without_autocast(): + lg_out, vector_out = self.model_lg(prompts) + t5_out = self.model_t5(prompts, token_count=lg_out.shape[1]) + lgt_out = torch.cat([lg_out, t5_out], dim=-2) + + return { + 'crossattn': lgt_out, + 'vector': vector_out, + } + + def before_load_weights(self, state_dict): + clip_path = os.path.join(shared.models_path, "CLIP") + + if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict: + clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors") + with safetensors.safe_open(clip_g_file, framework="pt") as file: + self.clip_g.transformer.load_state_dict(SafetensorsMapping(file)) + + if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict: + clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") + with safetensors.safe_open(clip_l_file, framework="pt") as file: + self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") + with safetensors.safe_open(t5_file, framework="pt") as file: + self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + def encode_embedding_init_text(self, init_text, nvpt): + return self.model_lg.encode_embedding_init_text(init_text, nvpt) + + def tokenize(self, texts): + return self.model_lg.tokenize(texts) + + def medvram_modules(self): + return [self.clip_g, self.clip_l, self.t5xxl] + + def get_token_count(self, text): + _, token_count = self.model_lg.process_texts([text]) + + return token_count + + def get_target_prompt_token_count(self, 
token_count): + return self.model_lg.get_target_prompt_token_count(token_count) diff --git a/modules/models/sd3/sd3_impls.py b/modules/models/sd3/sd3_impls.py new file mode 100644 index 000000000..59f11b2cb --- /dev/null +++ b/modules/models/sd3/sd3_impls.py @@ -0,0 +1,374 @@ +### Impls of the SD3 core diffusion model and VAE + +import torch +import math +import einops +from modules.models.sd3.mmdit import MMDiT +from PIL import Image + + +################################################################################################# +### MMDiT Model Wrapping +################################################################################################# + + +class ModelSamplingDiscreteFlow(torch.nn.Module): + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + def __init__(self, shift=1.0): + super().__init__() + self.shift = shift + timesteps = 1000 + ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.register_buffer('sigmas', ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + return sigma * noise + (1.0 - sigma) * latent_image + + +class BaseModel(torch.nn.Module): + """Wrapper around the core MM-DiT model""" + def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""): + super().__init__() + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": { + "in_features": context_shape[1], + "out_features": context_shape[0] + } + } + self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype) + self.model_sampling = ModelSamplingDiscreteFlow(shift=shift) + self.depth = depth + + def apply_model(self, x, sigma, c_crossattn=None, y=None): + dtype = self.get_dtype() + timestep = self.model_sampling.timestep(sigma).float() + model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) + + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) + + def get_dtype(self): + return self.diffusion_model.dtype + + +class 
CFGDenoiser(torch.nn.Module): + """Helper for applying CFG Scaling to diffusion outputs""" + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x, timestep, cond, uncond, cond_scale): + # Run cond and uncond in a batch together + batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]])) + # Then split and apply CFG Scaling + pos_out, neg_out = batched.chunk(2) + scaled = neg_out + (pos_out - neg_out) * cond_scale + return scaled + + +class SD3LatentFormat: + """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift""" + def __init__(self): + self.scale_factor = 1.5305 + self.shift_factor = 0.0609 + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scale_factor + + def process_out(self, latent): + return (latent / self.scale_factor) + self.shift_factor + + def decode_latent_to_preview(self, x0): + """Quick RGB approximate preview of sd3 latents""" + factors = torch.tensor([ + [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650], + [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889], + [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284], + [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047], + [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039], + [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481], + [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], + [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259] + ], device="cpu") + latent_image = x0[0].permute(1, 2, 0).cpu() @ factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +################################################################################################# +### K-Diffusion Sampling +################################################################################################# + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + return x[(...,) + (None,) * dims_to_append] + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +@torch.no_grad() +@torch.autocast("cuda", dtype=torch.float16) +def sample_euler(model, x, sigmas, extra_args=None): + """Implements Algorithm 2 (Euler steps) from Karras et al. 
(2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in range(len(sigmas) - 1): + sigma_hat = sigmas[i] + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +################################################################################################# +### VAE +################################################################################################# + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)] + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def 
__init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for _ in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + 
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py new file mode 100644 index 000000000..a8a30e7f6 --- /dev/null +++ b/modules/models/sd3/sd3_model.py @@ -0,0 +1,96 @@ +import contextlib + +import torch + +import k_diffusion +from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat +from modules.models.sd3.sd3_cond import SD3Cond + +from modules import shared, devices + + +class SD3Denoiser(k_diffusion.external.DiscreteSchedule): + def __init__(self, inner_model, sigmas): + super().__init__(sigmas, quantize=shared.opts.enable_quantization) + self.inner_model = inner_model + + def forward(self, input, sigma, **kwargs): + return self.inner_model.apply_model(input, sigma, **kwargs) + + +class SD3Inferencer(torch.nn.Module): + def __init__(self, state_dict, shift=3, use_ema=False): + super().__init__() + + self.shift = shift + + with torch.no_grad(): + self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype) + self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) + self.first_stage_model.dtype = self.model.diffusion_model.dtype + + self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) + + self.text_encoders = SD3Cond() + self.cond_stage_key = 'txt' + + self.parameterization = "eps" + self.model.conditioning_key = "crossattn" + + self.latent_format = SD3LatentFormat() + self.latent_channels = 16 + + @property + def cond_stage_model(self): + return self.text_encoders + + def before_load_weights(self, state_dict): + self.cond_stage_model.before_load_weights(state_dict) + + def ema_scope(self): + return contextlib.nullcontext() + + def get_learned_conditioning(self, batch: list[str]): + return self.cond_stage_model(batch) + + def apply_model(self, x, t, cond): + return 
self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector']) + + def decode_first_stage(self, latent): + latent = self.latent_format.process_out(latent) + return self.first_stage_model.decode(latent) + + def encode_first_stage(self, image): + latent = self.first_stage_model.encode(image) + return self.latent_format.process_in(latent) + + def get_first_stage_encoding(self, x): + return x + + def create_denoiser(self): + return SD3Denoiser(self, self.model.model_sampling.sigmas) + + def medvram_fields(self): + return [ + (self, 'first_stage_model'), + (self, 'text_encoders'), + (self, 'model'), + ] + + def add_noise_to_latent(self, x, noise, amount): + return x * (1 - amount) + noise * amount + + def fix_dimensions(self, width, height): + return width // 16 * 16, height // 16 * 16 + + def diffusers_weight_mapping(self): + for i in range(self.model.depth): + yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj" + + yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj" diff --git a/modules/options.py b/modules/options.py index 35ccade25..2a78a825e 100644 --- a/modules/options.py +++ b/modules/options.py @@ -240,6 +240,9 @@ def dumpjson(self): item_categories = {} for item in self.data_labels.values(): + if item.section[0] is None: + continue + category = categories.mapping.get(item.category_id) category = "Uncategorized" if category is None else category.label if category not in item_categories: diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 2ed1392a4..67521f5cd 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -7,7 +7,7 @@ from pathlib import Path -normalized_filepath = lambda filepath: str(Path(filepath).resolve()) +normalized_filepath = lambda filepath: str(Path(filepath).absolute()) commandline_args = os.environ.get('COMMANDLINE_ARGS', "") sys.argv += shlex.split(commandline_args) @@ -24,14 +24,15 @@ # Parse the --data-dir flag first so we can use it as a base for our other argument default values parser_pre = argparse.ArgumentParser(add_help=False) parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", ) +parser_pre.add_argument("--models-dir", type=str, default=None, help="base path where models are stored; overrides --data-dir", ) cmd_opts_pre = parser_pre.parse_known_args()[0] data_path = cmd_opts_pre.data_dir -models_path = os.path.join(data_path, "models") +models_path = cmd_opts_pre.models_dir if cmd_opts_pre.models_dir else os.path.join(data_path, "models") extensions_dir = os.path.join(data_path, "extensions") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") 
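For context on the `--models-dir` flag added above: it only redirects where models are looked up, while everything else still derives from `--data-dir`. Below is a minimal, standalone sketch of that resolution order (the parser here is illustrative; the real logic lives in modules/paths_internal.py):

```python
import argparse
import os

# Hypothetical standalone sketch of the precedence introduced by --models-dir.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--data-dir", type=str, default=".", help="base path where all user data is stored")
parser.add_argument("--models-dir", type=str, default=None, help="base path where models are stored; overrides --data-dir for models only")
args, _ = parser.parse_known_args()

data_path = args.data_dir
# --models-dir, when given, wins; otherwise models live under <data-dir>/models
models_path = args.models_dir if args.models_dir else os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")  # unaffected by --models-dir

print(models_path, extensions_dir)
```

For example, launching with `--data-dir /data/webui --models-dir /mnt/shared/models` keeps extensions, configs and outputs under /data/webui while checkpoints are read from /mnt/shared/models.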
config_states_dir = os.path.join(script_path, "config_states") -default_output_dir = os.path.join(data_path, "output") +default_output_dir = os.path.join(data_path, "outputs") roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') diff --git a/modules/postprocessing.py b/modules/postprocessing.py index f14882321..caf2fe4d7 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -13,14 +13,17 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, outputs = [] + if isinstance(image, dict): + image = image["composite"] + def get_images(extras_mode, image, image_folder, input_dir): if extras_mode == 1: for img in image_folder: if isinstance(img, Image.Image): - image = img + image = images.fix_image(img) fn = '' else: - image = Image.open(os.path.abspath(img.name)) + image = images.read(os.path.abspath(img.name)) fn = os.path.splitext(img.orig_name)[0] yield image, fn elif extras_mode == 2: @@ -51,22 +54,24 @@ def get_images(extras_mode, image, image_folder, input_dir): shared.state.textinfo = name shared.state.skipped = False - if shared.state.interrupted: + if shared.state.interrupted or shared.state.stopping_generation: break if isinstance(image_placeholder, str): try: - image_data = Image.open(image_placeholder) + image_data = images.read(image_placeholder) except Exception: continue else: image_data = image_placeholder + image_data = image_data if image_data.mode in ("RGBA", "RGB") else image_data.convert("RGB") + parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters - initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) + initial_pp = scripts_postprocessing.PostprocessedImage(image_data) scripts.scripts_postproc.run(initial_pp, args) @@ -122,8 +127,6 @@ def get_images(extras_mode, image, image_folder, input_dir): if extras_mode != 2 or show_extras_results: outputs.append(pp.image) - image_data.close() - devices.torch_gc() shared.state.end() return outputs, ui_common.plaintext_to_html(infotext), '' @@ -133,13 +136,15 @@ def run_postprocessing_webui(id_task, *args, **kwargs): return run_postprocessing(*args, **kwargs) -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True, max_side_length: int = 0): """old handler for API""" args = scripts.scripts_postproc.create_args_for_run({ "Upscale": { + "upscale_enabled": True, "upscale_mode": resize_mode, "upscale_by": upscaling_resize, + "max_side_length": max_side_length, "upscale_to_width": upscaling_resize_w, "upscale_to_height": upscaling_resize_h, "upscale_crop": upscaling_crop, diff --git a/modules/processing.py b/modules/processing.py index 64e564e00..1398d5b29 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, 
masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -117,20 +117,17 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - sd = sd_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if sd_model.is_sdxl_inpaint: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. 
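The hunk above replaces the old `diffusion_model.input_blocks.0.0.weight` shape probe with the `is_sdxl_inpaint` flag, but the conditioning it builds is unchanged: a uniform 0.5 image is VAE-encoded and a full channel of ones (the mask) is padded onto the front of the channel dimension. A minimal sketch of just that tensor bookkeeping, with a random stand-in for the encoded latent:

```python
import torch

# Minimal sketch of the fake "everything is masked" conditioning built above;
# the VAE encode of the 0.5-grey image is replaced by a random stand-in latent here.
batch, height, width = 2, 64, 64                             # latent-space height/width
masked_image_latent = torch.randn(batch, 4, height, width)   # stand-in for images_tensor_to_samples(...)

# Pad one channel of ones onto the front of dim=1: that channel is the mask ("everything masked").
image_conditioning = torch.nn.functional.pad(masked_image_latent, (0, 0, 0, 0, 1, 0), value=1.0)

print(image_conditioning.shape)           # torch.Size([2, 5, 64, 64])
print(image_conditioning[:, 0].unique())  # tensor([1.]) -- the mask channel
```

The sampler later concatenates this 5-channel conditioning with the 4-channel noisy latent, which is the 9-input-channel layout the removed `shape[1] == 9` check was probing for.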
@@ -154,6 +151,7 @@ class StableDiffusionProcessing: seed_resize_from_w: int = -1 seed_enable_extras: bool = True sampler_name: str = None + scheduler: str = None batch_size: int = 1 n_iter: int = 1 steps: int = 50 @@ -189,8 +187,8 @@ class StableDiffusionProcessing: script_args_value: list = field(default=None, init=False) scripts_setup_complete: bool = field(default=False, init=False) - cached_uc = [None, None] - cached_c = [None, None] + cached_uc = [None, None, None] + cached_c = [None, None, None] comments: dict = None sampler: sd_samplers_common.Sampler | None = field(default=None, init=False) @@ -229,6 +227,9 @@ class StableDiffusionProcessing: is_api: bool = field(default=False, init=False) + latents_after_sampling = [] + pixels_after_sampling = [] + def __post_init__(self): if self.sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -239,11 +240,6 @@ def __post_init__(self): self.styles = [] self.sampler_noise_scheduler_override = None - self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond - self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn - self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin - self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf') - self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise self.extra_generation_params = self.extra_generation_params or {} self.override_settings = self.override_settings or {} @@ -261,8 +257,17 @@ def __post_init__(self): self.cached_c = StableDiffusionProcessing.cached_c self.extra_result_images = [] + self.latents_after_sampling = [] + self.pixels_after_sampling = [] self.modified_noise = None + def fill_fields_from_opts(self): + self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond + self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn + self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin + self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf') + self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise + @property def sd_model(self): return shared.sd_model @@ -394,11 +399,8 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - sd = self.sampler.model_wrap.inner_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if self.sampler.model_wrap.inner_model.is_sdxl_inpaint: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. 
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -488,12 +490,14 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr for cache in caches: if cache[0] is not None and cached_params == cache[0]: + modules.sd_hijack.model_hijack.extra_generation_params.update(cache[2]) return cache[1] cache = caches[0] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) + cache[2] = modules.sd_hijack.model_hijack.extra_generation_params cache[0] = cached_params return cache[1] @@ -574,7 +578,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt] self.all_seeds = all_seeds or p.all_seeds or [self.seed] self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] - self.infotexts = infotexts or [info] + self.infotexts = infotexts or [info] * len(images_list) self.version = program_version() def js(self): @@ -613,7 +617,7 @@ def js(self): "version": self.version, } - return json.dumps(obj) + return json.dumps(obj, default=lambda o: None) def infotext(self, p: StableDiffusionProcessing, index): return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) @@ -672,7 +676,53 @@ def program_version(): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None): - if index is None: + """ + this function is used to generate the infotext that is stored in the generated images, it's contains the parameters that are required to generate the imagee + Args: + p: StableDiffusionProcessing + all_prompts: list[str] + all_seeds: list[int] + all_subseeds: list[int] + comments: list[str] + iteration: int + position_in_batch: int + use_main_prompt: bool + index: int + all_negative_prompts: list[str] + + Returns: str + + Extra generation params + p.extra_generation_params dictionary allows for additional parameters to be added to the infotext + this can be use by the base webui or extensions. + To add a new entry, add a new key value pair, the dictionary key will be used as the key of the parameter in the infotext + the value generation_params can be defined as: + - str | None + - List[str|None] + - callable func(**kwargs) -> str | None + + When defined as a string, it will be used as without extra processing; this is this most common use case. + + Defining as a list allows for parameter that changes across images in the job, for example, the 'Seed' parameter. + The list should have the same length as the total number of images in the entire job. + + Defining as a callable function allows parameter cannot be generated earlier or when extra logic is required. + For example 'Hires prompt', due to reasons the hr_prompt might be changed by process in the pipeline or extensions + and may vary across different images, defining as a static string or list would not work. + + The function takes locals() as **kwargs, as such will have access to variables like 'p' and 'index'. 
+ the base signature of the function should be: + func(**kwargs) -> str | None + optionally it can have additional arguments that will be used in the function: + func(p, index, **kwargs) -> str | None + note: for better future compatibility even though this function will have access to all variables in the locals(), + it is recommended to only use the arguments present in the function signature of create_infotext. + For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt. + """ + + if use_main_prompt: + index = 0 + elif index is None: index = position_in_batch + iteration * p.batch_size if all_negative_prompts is None: @@ -683,6 +733,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter token_merging_ratio = p.get_token_merging_ratio() token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True) + prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] + negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index] + uses_ensd = opts.eta_noise_seed_delta != 0 if uses_ensd: uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) @@ -690,6 +743,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter generation_params = { "Steps": p.steps, "Sampler": p.sampler_name, + "Schedule type": p.scheduler, "CFG scale": p.cfg_scale, "Image CFG scale": getattr(p, 'image_cfg_scale', None), "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index], @@ -712,17 +766,25 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr, "Init image hash": getattr(p, 'init_img_hash', None), "RNG": opts.randn_source if opts.randn_source != "GPU" else None, - "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "Tiling": "True" if p.tiling else None, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, "User": p.user if opts.add_user_name_to_info else None, } + for key, value in generation_params.items(): + try: + if isinstance(value, list): + generation_params[key] = value[index] + elif callable(value): + generation_params[key] = value(**locals()) + except Exception: + errors.report(f'Error creating infotext for key "{key}"', exc_info=True) + generation_params[key] = None + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) - prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] - negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" + negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else "" return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip() @@ -749,7 +811,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_vae': sd_vae.reload_vae_weights() - res = process_images_inner(p) + # backwards compatibility, fix sampler and scheduler if invalid + sd_samplers.fix_p_invalid_sampler_and_scheduler(p) + + with profiling.Profiler(): + res = process_images_inner(p) finally: # restore opts to original state @@ -787,6 +853,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.refiner_checkpoint_info is None: raise Exception(f'Could not find checkpoint with 
name {p.refiner_checkpoint}') + if hasattr(shared.sd_model, 'fix_dimensions'): + p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height) + p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra p.sd_model_hash = shared.sd_model.sd_model_hash p.sd_vae_name = sd_vae.get_loaded_vae_name() @@ -795,6 +864,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: apply_circular_forge(p.sd_model, p.tiling) modules.sd_hijack.model_hijack.clear_comments() + p.fill_fields_from_opts() p.setup_prompts() if isinstance(seed, list): @@ -845,7 +915,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] - p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C) + p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) if p.scripts is not None: p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -863,52 +934,26 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) + p.setup_conds() + + p.extra_generation_params.update(model_hijack.extra_generation_params) + # params.txt should be saved after scripts.process_batch, since the # infotext could be modified by that callback # Example: a wildcard processed by process_batch sets an extra model # strength, which is saved as "Model Strength: 1.0" in the infotext - if n == 0: + if n == 0 and not cmd_opts.no_prompt_history: with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, []) file.write(processed.infotext(p, 0)) - p.setup_conds() - for comment in model_hijack.comments: p.comment(comment) - p.extra_generation_params.update(model_hijack.extra_generation_params) - if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - def rescale_zero_terminal_snr_abar(alphas_cumprod): - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= (alphas_bar_sqrt_T) - - # Scale so the first timestep is back to the old value. 
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas_bar[-1] = 4.8973451890853435e-08 - return alphas_bar - - if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + sd_models.apply_alpha_schedule_override(p.sd_model, p) alphas_cumprod_modifiers = p.sd_model.forge_objects.unet.model_options.get('alphas_cumprod_modifiers', []) alphas_cumprod_backup = None @@ -921,6 +966,9 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + for x_sample in samples_ddim: + p.latents_after_sampling.append(x_sample) + if alphas_cumprod_backup is not None: p.sd_model.alphas_cumprod = alphas_cumprod_backup p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5) @@ -933,6 +981,8 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: + devices.test_for_nans(samples_ddim, "unet") + if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) @@ -979,7 +1029,7 @@ def infotext(index=0, use_main_prompt=False): image = Image.fromarray(x_sample) if p.scripts is not None: - pp = scripts.PostprocessImageArgs(image) + pp = scripts.PostprocessImageArgs(image, i + p.iteration * p.batch_size) p.scripts.postprocess_image(p, pp) image = pp.image @@ -1009,8 +1059,10 @@ def infotext(index=0, use_main_prompt=False): # and use it in the composite step. 
image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image) + p.pixels_after_sampling.append(image) + if p.scripts is not None: - pp = scripts.PostprocessImageArgs(image) + pp = scripts.PostprocessImageArgs(image, i + p.iteration * p.batch_size) p.scripts.postprocess_image_after_composite(p, pp) image = pp.image @@ -1109,12 +1161,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_resize_y: int = 0 hr_checkpoint_name: str = None hr_sampler_name: str = None + hr_scheduler: str = None hr_prompt: str = '' hr_negative_prompt: str = '' force_task_id: str = None - cached_hr_uc = [None, None] - cached_hr_c = [None, None] + cached_hr_uc = [None, None, None] + cached_hr_c = [None, None, None] hr_checkpoint_info: dict = field(default=None, init=False) hr_upscale_to_x: int = field(default=0, init=False) @@ -1197,11 +1250,21 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name - if tuple(self.hr_prompt) != tuple(self.prompt): - self.extra_generation_params["Hires prompt"] = self.hr_prompt + def get_hr_prompt(p, index, prompt_text, **kwargs): + hr_prompt = p.all_hr_prompts[index] + return hr_prompt if hr_prompt != prompt_text else None + + def get_hr_negative_prompt(p, index, negative_prompt, **kwargs): + hr_negative_prompt = p.all_hr_negative_prompts[index] + return hr_negative_prompt if hr_negative_prompt != negative_prompt else None + + self.extra_generation_params["Hires prompt"] = get_hr_prompt + self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt + + self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py - if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt): - self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt + if self.hr_scheduler is None: + self.hr_scheduler = self.scheduler self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and self.latent_scale_mode is None: @@ -1370,6 +1433,13 @@ def save_intermediate(image, index): if self.scripts is not None: self.scripts.before_hr(self) + self.scripts.process_before_every_sampling( + p=self, + x=samples, + noise=noise, + c=self.hr_c, + uc=self.hr_uc, + ) self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy() apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) @@ -1568,16 +1638,23 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') - crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding) - crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) - x1, y1, x2, y2 = crop_region - - mask = mask.crop(crop_region) - image_mask = images.resize_image(2, mask, self.width, self.height) - self.paste_to = (x1, y1, x2-x1, y2-y1) - - self.extra_generation_params["Inpaint area"] = "Only masked" - self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding + crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding) + if crop_region: + crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, 
mask.height) + x1, y1, x2, y2 = crop_region + mask = mask.crop(crop_region) + image_mask = images.resize_image(2, mask, self.width, self.height) + self.paste_to = (x1, y1, x2-x1, y2-y1) + self.extra_generation_params["Inpaint area"] = "Only masked" + self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding + else: + crop_region = None + image_mask = None + self.mask_for_overlay = None + self.inpaint_full_res = False + massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.' + model_hijack.comments.append(massage) + logging.info(massage) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) np_mask = np.array(image_mask) @@ -1588,6 +1665,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): latent_mask = self.latent_mask if self.latent_mask is not None else image_mask + if self.scripts is not None: + self.scripts.before_process_init_images(self, dict(crop_region=crop_region, image_mask=image_mask)) + add_color_corrections = opts.img2img_color_correction and self.color_corrections is None if add_color_corrections: self.color_corrections = [] @@ -1605,6 +1685,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: + if self.mask_for_overlay.size != (image.width, image.height): + self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height) image_masked = Image.new('RGBa', (image.width, image.height)) image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) @@ -1663,10 +1745,10 @@ def init(self, all_prompts, all_seeds, all_subseeds): latmask = latmask[0] if self.mask_round: latmask = np.around(latmask) - latmask = np.tile(latmask[None], (4, 1, 1)) + latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1)) - self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) - self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype) + self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype) + self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype) # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: diff --git a/modules/processing_scripts/comments.py b/modules/processing_scripts/comments.py index 638e39f29..cf81dfd8b 100644 --- a/modules/processing_scripts/comments.py +++ b/modules/processing_scripts/comments.py @@ -26,6 +26,13 @@ def process(self, p, *args): p.main_prompt = strip_comments(p.main_prompt) p.main_negative_prompt = strip_comments(p.main_negative_prompt) + if getattr(p, 'enable_hr', False): + p.all_hr_prompts = [strip_comments(x) for x in p.all_hr_prompts] + p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts] + + p.hr_prompt = strip_comments(p.hr_prompt) + p.hr_negative_prompt = strip_comments(p.hr_negative_prompt) + def before_token_counter(params: script_callbacks.BeforeTokenCounterParams): if not shared.opts.enable_prompt_comments: diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index ba33d8a4b..01504a5f4 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -22,7 +22,7 @@ def show(self, is_img2img): def ui(self, is_img2img): with InputAccordion(False, label="Refiner", 
elem_id=self.elem_id("enable")) as enable_refiner: with gr.Row(): - refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation") create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") diff --git a/modules/processing_scripts/sampler.py b/modules/processing_scripts/sampler.py new file mode 100644 index 000000000..1d465552c --- /dev/null +++ b/modules/processing_scripts/sampler.py @@ -0,0 +1,45 @@ +import gradio as gr + +from modules import scripts, sd_samplers, sd_schedulers, shared +from modules.infotext_utils import PasteField +from modules.ui_components import FormRow, FormGroup + + +class ScriptSampler(scripts.ScriptBuiltinUI): + section = "sampler" + + def __init__(self): + self.steps = None + self.sampler_name = None + self.scheduler = None + + def title(self): + return "Sampler" + + def ui(self, is_img2img): + sampler_names = [x.name for x in sd_samplers.visible_samplers()] + scheduler_names = [x.label for x in sd_schedulers.schedulers] + + if shared.opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{self.tabname}"): + self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0]) + self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0]) + self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{self.tabname}"): + self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20) + self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0]) + self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0]) + + self.infotext_fields = [ + PasteField(self.steps, "Steps", api="steps"), + PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"), + PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"), + ] + + return self.steps, self.sampler_name, self.scheduler + + def setup(self, p, steps, sampler_name, scheduler): + p.steps = steps + p.sampler_name = sampler_name + p.scheduler = scheduler diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 7a4c01598..717e8ef63 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -34,7 +34,7 @@ def ui(self, is_img2img): random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time") reuse_seed = 
ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized") - seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False) + seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False, scale=0, min_width=60) with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras: with gr.Row(elem_id=self.elem_id("subseed_row")): diff --git a/modules/profiling.py b/modules/profiling.py new file mode 100644 index 000000000..2729e0f30 --- /dev/null +++ b/modules/profiling.py @@ -0,0 +1,46 @@ +import torch + +from modules import shared, ui_gradio_extensions + + +class Profiler: + def __init__(self): + if not shared.opts.profiling_enable: + self.profiler = None + return + + activities = [] + if "CPU" in shared.opts.profiling_activities: + activities.append(torch.profiler.ProfilerActivity.CPU) + if "CUDA" in shared.opts.profiling_activities: + activities.append(torch.profiler.ProfilerActivity.CUDA) + + if not activities: + self.profiler = None + return + + self.profiler = torch.profiler.profile( + activities=activities, + record_shapes=shared.opts.profiling_record_shapes, + profile_memory=shared.opts.profiling_profile_memory, + with_stack=shared.opts.profiling_with_stack + ) + + def __enter__(self): + if self.profiler: + self.profiler.__enter__() + + return self + + def __exit__(self, exc_type, exc, exc_tb): + if self.profiler: + shared.state.textinfo = "Finishing profile..." + + self.profiler.__exit__(exc_type, exc, exc_tb) + + self.profiler.export_chrome_trace(shared.opts.profiling_filename) + + +def webpath(): + return ui_gradio_extensions.webpath(shared.opts.profiling_filename) + diff --git a/modules/progress.py b/modules/progress.py index 85255e821..6ab789cdf 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -1,3 +1,4 @@ +from __future__ import annotations import base64 import io import time @@ -66,11 +67,11 @@ class ProgressResponse(BaseModel): active: bool = Field(title="Whether the task is being worked on right now") queued: bool = Field(title="Whether the task is in queue") completed: bool = Field(title="Whether the task has already finished") - progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") - eta: float = Field(default=None, title="ETA in secs") - live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") - id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") - textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") + progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") + eta: float | None = Field(default=None, title="ETA in secs") + live_preview: str | None = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") + id_live_preview: int | None = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") + textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.") def setup_progress_api(app): diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index c8b423a0e..70aefbc77 100644 --- a/modules/prompt_parser.py +++ 
b/modules/prompt_parser.py @@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, class DictWithShape(dict): - def __init__(self, x): + def __init__(self, x, shape=None): super().__init__() self.update(x) diff --git a/modules/rng.py b/modules/rng.py index 4a3bbb207..f3afb4def 100644 --- a/modules/rng.py +++ b/modules/rng.py @@ -40,7 +40,7 @@ def randn_local(seed, shape): def randn_like(x): - """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + """Generate a tensor with random numbers from a normal distribution using the previously initialized generator. Use either randn() or manual_seed() to initialize the generator.""" @@ -54,7 +54,7 @@ def randn_like(x): def randn_without_seed(shape, generator=None): - """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + """Generate a tensor with random numbers from a normal distribution using the previously initialized generator. Use either randn() or manual_seed() to initialize the generator.""" diff --git a/modules/safe.py b/modules/safe.py index fe771c1b8..d1e242e89 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -64,8 +64,8 @@ def find_class(self, module, name): raise Exception(f"global '{module}/{name}' is forbidden") -# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/' -allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") +# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/' +allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$") data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") def check_zip_filenames(filename, names): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 2c50f43c5..9059d4d93 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import dataclasses import inspect import os -from collections import namedtuple from typing import Optional, Any from fastapi import FastAPI from gradio import Blocks -from modules import errors, timer +from modules import errors, timer, extensions, shared, util def report_exception(c, job): @@ -116,7 +117,105 @@ class BeforeTokenCounterParams: is_positive: bool = True -ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) +@dataclasses.dataclass +class ScriptCallback: + script: str + callback: any + name: str = "unnamed" + + +def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None): + if filename is None: + stack = [x for x in inspect.stack() if x.filename != __file__] + filename = stack[0].filename if stack else 'unknown file' + + extension = extensions.find_extension(filename) + extension_name = extension.canonical_name if extension else 'base' + + callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}" + if name is not None: + callback_name += f'/{name}' + + unique_callback_name = callback_name + for index in range(1000): + existing = any(x.name == unique_callback_name for x in callbacks) + if not existing: + break + + unique_callback_name = f'{callback_name}-{index+1}' + + callbacks.append(ScriptCallback(filename, fun, unique_callback_name)) + + +def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True): + callbacks = unordered_callbacks.copy() + 
callback_lookup = {x.name: x for x in callbacks} + dependencies = {} + + order_instructions = {} + for extension in extensions.extensions: + for order_instruction in extension.metadata.list_callback_order_instructions(): + if order_instruction.name in callback_lookup: + if order_instruction.name not in order_instructions: + order_instructions[order_instruction.name] = [] + + order_instructions[order_instruction.name].append(order_instruction) + + if order_instructions: + for callback in callbacks: + dependencies[callback.name] = [] + + for callback in callbacks: + for order_instruction in order_instructions.get(callback.name, []): + for after in order_instruction.after: + if after not in callback_lookup: + continue + + dependencies[callback.name].append(after) + + for before in order_instruction.before: + if before not in callback_lookup: + continue + + dependencies[before].append(callback.name) + + sorted_names = util.topological_sort(dependencies) + callbacks = [callback_lookup[x] for x in sorted_names] + + if enable_user_sort: + for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])): + index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None) + if index is not None: + callbacks.insert(0, callbacks.pop(index)) + + return callbacks + + +def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True): + if unordered_callbacks is None: + unordered_callbacks = callback_map.get('callbacks_' + category, []) + + if not enable_user_sort: + return sort_callbacks(category, unordered_callbacks, enable_user_sort=False) + + callbacks = ordered_callbacks_map.get(category) + if callbacks is not None and len(callbacks) == len(unordered_callbacks): + return callbacks + + callbacks = sort_callbacks(category, unordered_callbacks) + + ordered_callbacks_map[category] = callbacks + return callbacks + + +def enumerate_callbacks(): + for category, callbacks in callback_map.items(): + if category.startswith('callbacks_'): + category = category[10:] + + yield category, callbacks + + callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], @@ -140,18 +239,19 @@ class BeforeTokenCounterParams: callbacks_list_unets=[], callbacks_before_token_counter=[], ) -event_subscriber_map = dict( - callbacks_setting_updated=[], -) + +ordered_callbacks_map = {} def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() + ordered_callbacks_map.clear() + def app_started_callback(demo: Optional[Blocks], app: FastAPI): - for c in callback_map['callbacks_app_started']: + for c in ordered_callbacks('app_started'): try: c.callback(demo, app) timer.startup_timer.record(os.path.basename(c.script)) @@ -160,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): def app_reload_callback(): - for c in callback_map['callbacks_on_reload']: + for c in ordered_callbacks('on_reload'): try: c.callback() except Exception: @@ -168,7 +268,7 @@ def app_reload_callback(): def model_loaded_callback(sd_model): - for c in callback_map['callbacks_model_loaded']: + for c in ordered_callbacks('model_loaded'): try: c.callback(sd_model) except Exception: @@ -178,7 +278,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for c in callback_map['callbacks_ui_tabs']: + for c in ordered_callbacks('ui_tabs'): try: res += c.callback() or [] except Exception: @@ -188,7 +288,7 @@ def ui_tabs_callback(): def ui_train_tabs_callback(params: UiTrainTabParams): - for c in 
callback_map['callbacks_ui_train_tabs']: + for c in ordered_callbacks('ui_train_tabs'): try: c.callback(params) except Exception: @@ -196,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams): def ui_settings_callback(): - for c in callback_map['callbacks_ui_settings']: + for c in ordered_callbacks('ui_settings'): try: c.callback() except Exception: @@ -204,7 +304,7 @@ def ui_settings_callback(): def before_image_saved_callback(params: ImageSaveParams): - for c in callback_map['callbacks_before_image_saved']: + for c in ordered_callbacks('before_image_saved'): try: c.callback(params) except Exception: @@ -212,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams): - for c in callback_map['callbacks_image_saved']: + for c in ordered_callbacks('image_saved'): try: c.callback(params) except Exception: @@ -220,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams): def extra_noise_callback(params: ExtraNoiseParams): - for c in callback_map['callbacks_extra_noise']: + for c in ordered_callbacks('extra_noise'): try: c.callback(params) except Exception: @@ -228,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams): def cfg_denoiser_callback(params: CFGDenoiserParams): - for c in callback_map['callbacks_cfg_denoiser']: + for c in ordered_callbacks('cfg_denoiser'): try: c.callback(params) except Exception: @@ -236,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): def cfg_denoised_callback(params: CFGDenoisedParams): - for c in callback_map['callbacks_cfg_denoised']: + for c in ordered_callbacks('cfg_denoised'): try: c.callback(params) except Exception: @@ -244,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams): def cfg_after_cfg_callback(params: AfterCFGCallbackParams): - for c in callback_map['callbacks_cfg_after_cfg']: + for c in ordered_callbacks('cfg_after_cfg'): try: c.callback(params) except Exception: @@ -252,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams): def before_component_callback(component, **kwargs): - for c in callback_map['callbacks_before_component']: + for c in ordered_callbacks('before_component'): try: c.callback(component, **kwargs) except Exception: @@ -260,7 +360,7 @@ def before_component_callback(component, **kwargs): def after_component_callback(component, **kwargs): - for c in callback_map['callbacks_after_component']: + for c in ordered_callbacks('after_component'): try: c.callback(component, **kwargs) except Exception: @@ -268,7 +368,7 @@ def after_component_callback(component, **kwargs): def image_grid_callback(params: ImageGridLoopParams): - for c in callback_map['callbacks_image_grid']: + for c in ordered_callbacks('image_grid'): try: c.callback(params) except Exception: @@ -276,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams): def infotext_pasted_callback(infotext: str, params: dict[str, Any]): - for c in callback_map['callbacks_infotext_pasted']: + for c in ordered_callbacks('infotext_pasted'): try: c.callback(infotext, params) except Exception: @@ -284,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]): def script_unloaded_callback(): - for c in reversed(callback_map['callbacks_script_unloaded']): + for c in reversed(ordered_callbacks('script_unloaded')): try: c.callback() except Exception: @@ -292,7 +392,7 @@ def script_unloaded_callback(): def before_ui_callback(): - for c in reversed(callback_map['callbacks_before_ui']): + for c in reversed(ordered_callbacks('before_ui')): 
try: c.callback() except Exception: @@ -302,7 +402,7 @@ def before_ui_callback(): def list_optimizers_callback(): res = [] - for c in callback_map['callbacks_list_optimizers']: + for c in ordered_callbacks('list_optimizers'): try: c.callback(res) except Exception: @@ -314,7 +414,7 @@ def list_optimizers_callback(): def list_unets_callback(): res = [] - for c in callback_map['callbacks_list_unets']: + for c in ordered_callbacks('list_unets'): try: c.callback(res) except Exception: @@ -324,37 +424,13 @@ def list_unets_callback(): def before_token_counter_callback(params: BeforeTokenCounterParams): - for c in callback_map['callbacks_before_token_counter']: + for c in ordered_callbacks('before_token_counter'): try: c.callback(params) except Exception: report_exception(c, 'before_token_counter') -def setting_updated_event_subscriber_chain(handler, component, setting_name: str): - """ - Arguments: - - handler: The returned handler from calling an event subscriber. - - component: The component that is updated. The component should provide - the value of setting after update. - - setting_name: The name of the setting. - """ - for param in event_subscriber_map['callbacks_setting_updated']: - handler = handler.then( - fn=lambda *args: param["fn"](*args, setting_name), - inputs=param["inputs"] + [component], - outputs=param["outputs"], - show_progress=False, - ) - - -def add_callback(callbacks, fun): - stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if stack else 'unknown file' - - callbacks.append(ScriptCallback(filename, fun)) - - def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -363,32 +439,38 @@ def remove_current_script_callbacks(): for callback_list in callback_map.values(): for callback_to_remove in [cb for cb in callback_list if cb.script == filename]: callback_list.remove(callback_to_remove) + for ordered_callbacks_list in ordered_callbacks_map.values(): + for callback_to_remove in [cb for cb in ordered_callbacks_list if cb.script == filename]: + ordered_callbacks_list.remove(callback_to_remove) def remove_callbacks_for_function(callback_func): for callback_list in callback_map.values(): for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]: callback_list.remove(callback_to_remove) + for ordered_callback_list in ordered_callbacks_map.values(): + for callback_to_remove in [cb for cb in ordered_callback_list if cb.callback == callback_func]: + ordered_callback_list.remove(callback_to_remove) -def on_app_started(callback): +def on_app_started(callback, *, name=None): """register a function to be called when the webui started, the gradio `Block` component and fastapi `FastAPI` object are passed as the arguments""" - add_callback(callback_map['callbacks_app_started'], callback) + add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started') -def on_before_reload(callback): +def on_before_reload(callback, *, name=None): """register a function to be called just before the server reloads.""" - add_callback(callback_map['callbacks_on_reload'], callback) + add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload') -def on_model_loaded(callback): +def on_model_loaded(callback, *, name=None): """register a function to be called when the stable diffusion model is created; the model is passed as an argument; this function is also called when the 
script is reloaded. """ - add_callback(callback_map['callbacks_model_loaded'], callback) + add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded') -def on_ui_tabs(callback): +def on_ui_tabs(callback, *, name=None): """register a function to be called when the UI is creating new tabs. The function must either return a None, which means no new tabs to be added, or a list, where each element is a tuple: @@ -398,71 +480,71 @@ def on_ui_tabs(callback): title is tab text displayed to user in the UI elem_id is HTML id for the tab """ - add_callback(callback_map['callbacks_ui_tabs'], callback) + add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs') -def on_ui_train_tabs(callback): +def on_ui_train_tabs(callback, *, name=None): """register a function to be called when the UI is creating new tabs for the train tab. Create your new tabs with gr.Tab. """ - add_callback(callback_map['callbacks_ui_train_tabs'], callback) + add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs') -def on_ui_settings(callback): +def on_ui_settings(callback, *, name=None): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ - add_callback(callback_map['callbacks_ui_settings'], callback) + add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings') -def on_before_image_saved(callback): +def on_before_image_saved(callback, *, name=None): """register a function to be called before an image is saved to a file. The callback is called with one argument: - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. """ - add_callback(callback_map['callbacks_before_image_saved'], callback) + add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved') -def on_image_saved(callback): +def on_image_saved(callback, *, name=None): """register a function to be called after an image is saved to a file. The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ - add_callback(callback_map['callbacks_image_saved'], callback) + add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved') -def on_extra_noise(callback): +def on_extra_noise(callback, *, name=None): """register a function to be called before adding extra noise in img2img or hires fix; The callback is called with one argument: - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image """ - add_callback(callback_map['callbacks_extra_noise'], callback) + add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise') -def on_cfg_denoiser(callback): +def on_cfg_denoiser(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. 
""" - add_callback(callback_map['callbacks_cfg_denoiser'], callback) + add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser') -def on_cfg_denoised(callback): +def on_cfg_denoised(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(callback_map['callbacks_cfg_denoised'], callback) + add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised') -def on_cfg_after_cfg(callback): +def on_cfg_after_cfg(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. The callback is called with one argument: - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. """ - add_callback(callback_map['callbacks_cfg_after_cfg'], callback) + add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg') -def on_before_component(callback): +def on_before_component(callback, *, name=None): """register a function to be called before a component is created. The callback is called with arguments: - component - gradio component that is about to be created. @@ -471,72 +553,61 @@ def on_before_component(callback): Use elem_id/label fields of kwargs to figure out which component it is. This can be useful to inject your own components somewhere in the middle of vanilla UI. """ - add_callback(callback_map['callbacks_before_component'], callback) + add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component') -def on_after_component(callback): +def on_after_component(callback, *, name=None): """register a function to be called after a component is created. See on_before_component for more.""" - add_callback(callback_map['callbacks_after_component'], callback) + add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component') -def on_image_grid(callback): +def on_image_grid(callback, *, name=None): """register a function to be called before making an image grid. The callback is called with one argument: - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ - add_callback(callback_map['callbacks_image_grid'], callback) + add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid') -def on_infotext_pasted(callback): +def on_infotext_pasted(callback, *, name=None): """register a function to be called before applying an infotext. The callback is called with two arguments: - infotext: str - raw infotext. - result: dict[str, any] - parsed infotext parameters. """ - add_callback(callback_map['callbacks_infotext_pasted'], callback) + add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted') -def on_script_unloaded(callback): +def on_script_unloaded(callback, *, name=None): """register a function to be called before the script is unloaded. 
Any hooks/hijacks/monkeying about that the script did should be reverted here""" - add_callback(callback_map['callbacks_script_unloaded'], callback) + add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded') -def on_before_ui(callback): +def on_before_ui(callback, *, name=None): """register a function to be called before the UI is created.""" - add_callback(callback_map['callbacks_before_ui'], callback) + add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui') -def on_list_optimizers(callback): +def on_list_optimizers(callback, *, name=None): """register a function to be called when UI is making a list of cross attention optimization options. The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization to it.""" - add_callback(callback_map['callbacks_list_optimizers'], callback) + add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers') -def on_list_unets(callback): +def on_list_unets(callback, *, name=None): """register a function to be called when UI is making a list of alternative options for unet. The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" - add_callback(callback_map['callbacks_list_unets'], callback) + add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets') -def on_before_token_counter(callback): +def on_before_token_counter(callback, *, name=None): """register a function to be called when UI is counting tokens for a prompt. The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" - add_callback(callback_map['callbacks_before_token_counter'], callback) - - -def on_setting_updated_subscriber(subscriber_params): - """register a function to be called after settings update. `subscriber_params` - should contain necessary fields to register an gradio event handler. Necessary - fields are ["fn", "outputs", "inputs"]. - Setting name and setting value after update will be append to inputs. So be - sure to handle these extra params when defining the callback function. 
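A minimal sketch of registering a callback through the helpers above using the new optional name keyword; the handler and the chosen name are hypothetical:

from modules import script_callbacks

def my_app_started(demo, app):
    # called once the gradio Blocks and FastAPI app exist
    ...

# add_callback combines this with the extension, file name and category into an
# identifier of the form "<extension>/<file>/app_started/startup-hook", which
# callback ordering (extension metadata before/after instructions and the
# prioritized_callbacks_* options) can then refer to
script_callbacks.on_app_started(my_app_started, name="startup-hook")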
- """ - event_subscriber_map['callbacks_setting_updated'].append(subscriber_params) - + add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter') diff --git a/modules/script_loading.py b/modules/script_loading.py index 0d55f1932..cccb30966 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -4,11 +4,15 @@ from modules import errors +loaded_scripts = {} + + def load_module(path): module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) module = importlib.util.module_from_spec(module_spec) module_spec.loader.exec_module(module) + loaded_scripts[path] = module return module diff --git a/modules/scripts.py b/modules/scripts.py index 79c5cb767..685403bcb 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,9 @@ import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util + +topological_sort = util.topological_sort AlwaysVisible = object() @@ -28,8 +30,9 @@ def __init__(self, samples): self.samples = samples class PostprocessImageArgs: - def __init__(self, image): + def __init__(self, image, index): self.image = image + self.index = index class PostProcessMaskOverlayArgs: def __init__(self, index, mask_for_overlay, overlay_image): @@ -92,7 +95,7 @@ class Script: """If true, the script setup will only be run in Gradio UI, not in API""" controls = None - """A list of controls retured by the ui().""" + """A list of controls returned by the ui().""" sorting_priority = 0 """Larger number will appear downwards in the UI.""" @@ -112,7 +115,7 @@ def ui(self, is_img2img): def show(self, is_img2img): """ - is_img2img is True if this function is called for the img2img interface, and Fasle otherwise + is_img2img is True if this function is called for the img2img interface, and False otherwise This function should return: - False if the script should not be shown in UI at all @@ -141,7 +144,6 @@ def setup(self, p, *args): """ pass - def before_process(self, p, *args): """ This function is called very early during processing begins for AlwaysVisible scripts. @@ -194,7 +196,6 @@ def process_before_every_sampling(self, p, *args, **kwargs): Similar to process(), called before every sampling. If you use high-res fix, this will be called two times. """ - pass def process_batch(self, p, *args, **kwargs): @@ -362,6 +363,9 @@ def elem_id(self, item_id): return f'{tabname}{item_id}' + def show(self, is_img2img): + return AlwaysVisible + current_basedir = paths.script_path @@ -380,29 +384,6 @@ def basedir(): postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) -def topological_sort(dependencies): - """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. 
- Ignores errors relating to missing dependeencies or circular dependencies - """ - - visited = {} - result = [] - - def inner(name): - visited[name] = True - - for dep in dependencies.get(name, []): - if dep in dependencies and dep not in visited: - inner(dep) - - result.append(name) - - for depname in dependencies: - if depname not in visited: - inner(depname) - - return result - @dataclass class ScriptWithDependencies: @@ -579,6 +560,25 @@ def __init__(self): self.paste_field_names = [] self.inputs = [None] + self.callback_map = {} + self.callback_names = [ + 'before_process', + 'process', + 'before_process_batch', + 'after_extra_networks_activate', + 'process_batch', + 'postprocess', + 'postprocess_batch', + 'postprocess_batch_list', + 'post_sample', + 'on_mask_blend', + 'postprocess_image', + 'postprocess_maskoverlay', + 'postprocess_image_after_composite', + 'before_component', + 'after_component', + ] + self.on_before_component_elem_id = {} """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks""" @@ -617,6 +617,8 @@ def initialize_scripts(self, is_img2img): self.scripts.append(script) self.selectable_scripts.append(script) + self.callback_map.clear() + self.apply_on_before_component_callbacks() def apply_on_before_component_callbacks(self): @@ -756,12 +758,17 @@ def init_field(title): def onload_script_visibility(params): title = params.get('Script', None) if title: - title_index = self.titles.index(title) - visibility = title_index == self.script_load_ctr - self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles) - return gr.update(visible=visibility) - else: - return gr.update(visible=False) + try: + title_index = self.titles.index(title) + visibility = title_index == self.script_load_ctr + self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles) + return gr.update(visible=visibility) + except ValueError: + params['Script'] = None + massage = f'Cannot find Script: "{title}"' + print(massage) + gr.Warning(massage) + return gr.update(visible=False) self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None')))) self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts]) @@ -788,8 +795,42 @@ def run(self, p, *args): return processed + def list_scripts_for_method(self, method_name): + if method_name in ('before_component', 'after_component'): + return self.scripts + else: + return self.alwayson_scripts + + def create_ordered_callbacks_list(self, method_name, *, enable_user_sort=True): + script_list = self.list_scripts_for_method(method_name) + category = f'script_{method_name}' + callbacks = [] + + for script in script_list: + if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None): + continue + + script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename) + + return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort) + + def ordered_callbacks(self, method_name, *, enable_user_sort=True): + script_list = self.list_scripts_for_method(method_name) + category = f'script_{method_name}' + + scrpts_len, callbacks = self.callback_map.get(category, (-1, None)) + + if callbacks is None or scrpts_len != len(script_list): + callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort) + self.callback_map[category] = len(script_list), callbacks + + return callbacks + + def 
ordered_scripts(self, method_name): + return [x.callback for x in self.ordered_callbacks(method_name)] + def before_process(self, p): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('before_process'): try: script_args = p.script_args[script.args_from:script.args_to] script.before_process(p, *script_args) @@ -797,23 +838,39 @@ def before_process(self, p): errors.report(f"Error running before_process: {script.filename}", exc_info=True) def process(self, p): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('process'): try: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: errors.report(f"Error running process: {script.filename}", exc_info=True) + def process_before_every_sampling(self, p, **kwargs): + for script in self.ordered_scripts('process_before_every_sampling'): + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process_before_every_sampling(p, *script_args, **kwargs) + except Exception: + errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True) + def before_process_batch(self, p, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('before_process_batch'): try: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) + def before_process_init_images(self, p, pp, **kwargs): + for script in self.ordered_scripts('before_process_init_images'): + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process_init_images(p, pp, *script_args, **kwargs) + except Exception: + errors.report(f"Error running before_process_init_images: {script.filename}", exc_info=True) + def after_extra_networks_activate(self, p, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('after_extra_networks_activate'): try: script_args = p.script_args[script.args_from:script.args_to] script.after_extra_networks_activate(p, *script_args, **kwargs) @@ -821,7 +878,7 @@ def after_extra_networks_activate(self, p, **kwargs): errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('process_batch'): try: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) @@ -837,7 +894,7 @@ def process_before_every_sampling(self, p, **kwargs): errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True) def postprocess(self, p, processed): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) @@ -845,7 +902,7 @@ def postprocess(self, p, processed): errors.report(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_batch'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) @@ -853,7 +910,7 @@ def postprocess_batch(self, p, images, **kwargs): errors.report(f"Error running 
postprocess_batch: {script.filename}", exc_info=True) def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_batch_list'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch_list(p, pp, *script_args, **kwargs) @@ -861,7 +918,7 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) def post_sample(self, p, ps: PostSampleArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('post_sample'): try: script_args = p.script_args[script.args_from:script.args_to] script.post_sample(p, ps, *script_args) @@ -869,7 +926,7 @@ def post_sample(self, p, ps: PostSampleArgs): errors.report(f"Error running post_sample: {script.filename}", exc_info=True) def on_mask_blend(self, p, mba: MaskBlendArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('on_mask_blend'): try: script_args = p.script_args[script.args_from:script.args_to] script.on_mask_blend(p, mba, *script_args) @@ -877,7 +934,7 @@ def on_mask_blend(self, p, mba: MaskBlendArgs): errors.report(f"Error running post_sample: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_image'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) @@ -885,7 +942,7 @@ def postprocess_image(self, p, pp: PostprocessImageArgs): errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_maskoverlay'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_maskoverlay(p, ppmo, *script_args) @@ -893,7 +950,7 @@ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('postprocess_image_after_composite'): try: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image_after_composite(p, pp, *script_args) @@ -907,7 +964,7 @@ def before_component(self, component, **kwargs): except Exception: errors.report(f"Error running on_before_component: {script.filename}", exc_info=True) - for script in self.scripts: + for script in self.ordered_scripts('before_component'): try: script.before_component(component, **kwargs) except Exception: @@ -920,7 +977,7 @@ def after_component(self, component, **kwargs): except Exception: errors.report(f"Error running on_after_component: {script.filename}", exc_info=True) - for script in self.scripts: + for script in self.ordered_scripts('after_component'): try: script.after_component(component, **kwargs) except Exception: @@ -948,7 +1005,7 @@ def reload_sources(self, cache): self.scripts[si].args_to = args_to def before_hr(self, p): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('before_hr'): try: script_args = p.script_args[script.args_from:script.args_to] script.before_hr(p, *script_args) @@ -956,7 +1013,7 @@ def before_hr(self, p): 
errors.report(f"Error running before_hr: {script.filename}", exc_info=True) def setup_scrips(self, p, *, is_ui=True): - for script in self.alwayson_scripts: + for script in self.ordered_scripts('setup'): if not is_ui and script.setup_for_ui_only: continue diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 901cad080..4b3b7afda 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -143,6 +143,7 @@ def scripts_in_preferred_order(self): self.initialize_scripts(modules.scripts.postprocessing_scripts_data) scripts_order = shared.opts.postprocessing_operation_order + scripts_filter_out = set(shared.opts.postprocessing_disable_in_extras) def script_score(name): for i, possible_match in enumerate(scripts_order): @@ -151,9 +152,10 @@ def script_score(name): return len(self.scripts) - script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)} + filtered_scripts = [script for script in self.scripts if script.name not in scripts_filter_out] + script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(filtered_scripts)} - return sorted(self.scripts, key=lambda x: script_scores[x.name]) + return sorted(filtered_scripts, key=lambda x: script_scores[x.name]) def setup_ui(self): inputs = [] diff --git a/modules/sd_emphasis.py b/modules/sd_emphasis.py index 654817b60..49ef1a6ac 100644 --- a/modules/sd_emphasis.py +++ b/modules/sd_emphasis.py @@ -35,7 +35,7 @@ class EmphasisIgnore(Emphasis): class EmphasisOriginal(Emphasis): name = "Original" - description = "the orginal emphasis implementation" + description = "the original emphasis implementation" def after_transformers(self): original_mean = self.z.mean() @@ -48,7 +48,7 @@ def after_transformers(self): class EmphasisOriginalNoNorm(EmphasisOriginal): name = "No norm" - description = "same as orginal, but without normalization (seems to work better for SDXL)" + description = "same as original, but without normalization (seems to work better for SDXL)" def after_transformers(self): self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index e55cd2174..8050b2786 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -185,13 +185,28 @@ def forward(self, input_ids): vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec emb = devices.cond_cast_unet(vec) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) vecs.append(tensor) return torch.stack(vecs) +class TextualInversionEmbeddings(torch.nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + + self.embeddings = model_hijack + self.textual_inversion_key = textual_inversion_key + + @property + def wrapped(self): + return super().forward + + def forward(self, input_ids): + return EmbeddingsWithFixes.forward(self, input_ids) + + def add_circular_option_to_conv_2d(): conv2d_constructor = torch.nn.Conv2d.__init__ diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py 
index 98350ac43..a479148fc 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -23,28 +23,25 @@ def __init__(self): PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) """An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt -chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally +chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" -class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): - """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to - have unlimited prompt length and assign weights to tokens in prompt. - """ - - def __init__(self, wrapped, hijack): +class TextConditionalModel(torch.nn.Module): + def __init__(self): super().__init__() - self.wrapped = wrapped - """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, - depending on model.""" - - self.hijack: sd_hijack.StableDiffusionModelHijack = hijack + self.hijack = sd_hijack.model_hijack self.chunk_length = 75 - self.is_trainable = getattr(wrapped, 'is_trainable', False) - self.input_key = getattr(wrapped, 'input_key', 'txt') - self.legacy_ucg_val = None + self.is_trainable = False + self.input_key = 'txt' + self.return_pooled = False + + self.comma_token = None + self.id_start = None + self.id_end = None + self.id_pad = None def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -66,7 +63,7 @@ def tokenize(self, texts): def encode_with_transformers(self, tokens): """ - converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; + converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens; All python lists with tokens are assumed to have same length, usually 77. if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on model - can be 768 and 1024. @@ -136,7 +133,7 @@ def next_chunk(is_last=False): if token == self.comma_token: last_comma = len(chunk.tokens) - # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: break_location = last_comma + 1 @@ -206,14 +203,10 @@ def forward(self, texts): be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. An example shape returned by this function can be: (2, 77, 768). For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. 
- Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ - if opts.use_old_emphasis_implementation: - import modules.sd_hijack_clip_old - return modules.sd_hijack_clip_old.forward_old(self, texts) - batch_chunks, token_count = self.process_texts(texts) used_embeddings = {} @@ -230,7 +223,7 @@ def forward(self, texts): for fixes in self.hijack.fixes: for _position, embedding in fixes: used_embeddings[embedding.name] = embedding - + devices.torch_npu_set_device() z = self.process_tokens(tokens, multipliers) zs.append(z) @@ -252,7 +245,7 @@ def forward(self, texts): if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": self.hijack.extra_generation_params["Emphasis"] = opts.emphasis - if getattr(self.wrapped, 'return_pooled', False): + if self.return_pooled: return torch.hstack(zs), zs[0].pooled else: return torch.hstack(zs) @@ -292,6 +285,34 @@ def process_tokens(self, remade_batch_tokens, batch_multipliers): return z +class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. + """ + + def __init__(self, wrapped, hijack): + super().__init__() + + self.hijack = hijack + + self.wrapped = wrapped + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.return_pooled = getattr(self.wrapped, 'return_pooled', False) + + self.legacy_ucg_val = None # for sgm codebase + + def forward(self, texts): + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) + + return super().forward(texts) + + class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) @@ -353,7 +374,9 @@ def __init__(self, wrapped, hijack): def encode_with_transformers(self, tokens): outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") - if self.wrapped.layer == "last": + if opts.sdxl_clip_l_skip is True: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + elif self.wrapped.layer == "last": z = outputs.last_hidden_state else: z = outputs.hidden_states[self.wrapped.layer_idx] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7f9e328d0..0269f1f5b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -486,7 +486,8 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in)) + q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) + del q_in, k_in, v_in dtype = q.dtype @@ -497,7 +498,8 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): out = out.to(dtype) - out = rearrange(out, 'b n h d -> b n 
(h d)', h=h) + b, n, h, d = out.shape + out = out.reshape(b, n, h * d) return self.to_out(out) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 2101f1a04..b4f03b138 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,5 +1,7 @@ import torch from packaging import version +from einops import repeat +import math from modules import devices from modules.sd_hijack_utils import CondFunc @@ -36,7 +38,7 @@ def cat(self, tensors, *args, **kwargs): # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - + """Always make sure inputs to unet are in correct dtype.""" if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): @@ -45,7 +47,59 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + if devices.unet_needs_upcast: + return result.float() + else: + return result + + +# Monkey patch to create timestep embed tensor on device, avoiding a block. +def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
+# Prevents a lot of unnecessary aten::copy_ calls +def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = x.view(b, h, w, c).permute(0, 3, 1, 2) + if not self.use_linear: + x = self.proj_out(x) + return x + x_in class GELUHijack(torch.nn.GELU, torch.nn.Module): @@ -64,12 +118,15 @@ def hijack_ddpm_edit(): if not ddpm_edit_hijack: CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) - ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) +CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) + if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) @@ -81,5 +138,17 @@ def hijack_ddpm_edit(): CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) -CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) + + +def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): + if devices.unet_needs_upcast and timesteps.dtype == torch.int64: + dtype = torch.float32 + else: + dtype = devices.dtype_unet + return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) + + +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) 
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index 79bf6e468..546f2eda4 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -1,7 +1,11 @@ import importlib + +always_true_func = lambda *args, **kwargs: True + + class CondFunc: - def __new__(cls, orig_func, sub_func, cond_func): + def __new__(cls, orig_func, sub_func, cond_func=always_true_func): self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') @@ -20,13 +24,13 @@ def __new__(cls, orig_func, sub_func, cond_func): print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") pass self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 185062dbb..538f5577d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,19 +1,18 @@ import collections -import os.path +import importlib +import os import sys import threading +import enum import torch import re import safetensors.torch from omegaconf import OmegaConf, ListConfig -from os import mkdir from urllib import request import ldm.modules.midas as midas import gc -from ldm.util import instantiate_from_config - from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules.timer import Timer import numpy as np @@ -33,6 +32,14 @@ checkpoints_loaded = collections.OrderedDict() +class ModelType(enum.Enum): + SD1 = 1 + SD2 = 2 + SDXL = 3 + SSD = 4 + SD3 = 5 + + def replace_key(d, key, new_key, value): keys = list(d.keys()) @@ -155,6 +162,7 @@ def list_models(): cmd_ckpt = shared.cmd_opts.ckpt if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt): model_url = None + expected_sha256 = None else: model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors" @@ -286,17 +294,21 @@ def read_metadata_from_safetensors(filename): json_start = file.read(2) assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file" - json_data = json_start + file.read(metadata_len-2) - json_obj = json.loads(json_data) res = {} - for k, v in json_obj.get("__metadata__", {}).items(): - res[k] = v - if isinstance(v, str) and v[0:1] == '{': - try: - res[k] = json.loads(v) - except Exception: - pass + + try: + json_data = json_start + file.read(metadata_len-2) + json_obj = json.loads(json_data) + for 
k, v in json_obj.get("__metadata__", {}).items(): + res[k] = v + if isinstance(v, str) and v[0:1] == '{': + try: + res[k] = json.loads(v) + except Exception: + pass + except Exception: + errors.report(f"Error reading metadata from file: {filename}", exc_info=True) return res @@ -368,42 +380,39 @@ def check_fp8(model): return enable_fp8 -def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): - sd_model_hash = checkpoint_info.calculate_shorthash() - timer.record("calculate hash") - - if not SkipWritingToConfig.skip: - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - - if state_dict is None: - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) +def set_model_type(model, state_dict): + model.is_sd1 = False + model.is_sd2 = False + model.is_sdxl = False + model.is_ssd = False + model.is_sd3 = False - if shared.opts.sd_checkpoint_cache > 0: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = state_dict.copy() + if "model.diffusion_model.x_embedder.proj.weight" in state_dict: + model.is_sd3 = True + model.model_type = ModelType.SD3 + elif hasattr(model, 'conditioner'): + model.is_sdxl = True - model.load_state_dict(state_dict, strict=False) - timer.record("apply weights to model") - - del state_dict + if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys(): + model.is_ssd = True + model.model_type = ModelType.SSD + else: + model.model_type = ModelType.SDXL + elif hasattr(model.cond_stage_model, 'model'): + model.is_sd2 = True + model.model_type = ModelType.SD2 + else: + model.is_sd1 = True + model.model_type = ModelType.SD1 - # clean up cache if limit is reached - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: - checkpoints_loaded.popitem(last=False) - model.sd_model_hash = sd_model_hash - model.sd_model_checkpoint = checkpoint_info.filename - model.sd_checkpoint_info = checkpoint_info - shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 +def set_model_fields(model): + if not hasattr(model, 'latent_channels'): + model.latent_channels = 4 - if hasattr(model, 'logvar'): - model.logvar = model.logvar.to(devices.device) # fix for training - sd_vae.delete_base_vae() - sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple() - sd_vae.load_vae(model, vae_file, vae_source) - timer.record("load VAE") +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): + return def enable_midas_autodownload(): @@ -438,7 +447,7 @@ def load_model_wrapper(model_type): path = midas.api.ISL_PATHS[model_type] if not os.path.exists(path): if not os.path.exists(midas_path): - mkdir(midas_path) + os.mkdir(midas_path) print(f"Downloading midas model weights for {model_type} to {path}") request.urlretrieve(midas_urls[model_type], path) @@ -463,25 +472,76 @@ def patched_register_schedule(*args, **kwargs): original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule) -def repair_config(sd_config): - +def repair_config(sd_config, state_dict=None): if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False if hasattr(sd_config.model.params, 'unet_config'): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: + elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half": sd_config.model.params.unet_config.params.use_fp16 
= True - if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: - sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" + if hasattr(sd_config.model.params, 'first_stage_config'): + if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: + sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" # For UnCLIP-L, override the hardcoded karlo directory if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"): karlo_path = os.path.join(paths.models_path, 'karlo') sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + # Do not use checkpoint for inference. + # This helps prevent extra performance overhead on checking parameters. + # The perf overhead is about 100ms/it on 4090 for SDXL. + if hasattr(sd_config.model.params, "network_config"): + sd_config.model.params.network_config.params.use_checkpoint = False + if hasattr(sd_config.model.params, "unet_config"): + sd_config.model.params.unet_config.params.use_checkpoint = False + + + +def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + +def apply_alpha_schedule_override(sd_model, p=None): + """ + Applies an override to the alpha schedule of the model according to settings. 
+ - downcasts the alpha schedule to half precision + - rescales the alpha schedule to have zero terminal SNR + """ + + if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'): + return + + sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + if p is not None: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device) + + if opts.sd_noise_schedule == "Zero Terminal SNR": + if p is not None: + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device) + sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' @@ -532,11 +592,15 @@ def get_empty_cond(sd_model): p = processing.StableDiffusionProcessingTxt2Img() extra_networks.activate(p, {}) - if hasattr(sd_model, 'conditioner'): + if hasattr(sd_model, 'get_learned_conditioning'): d = sd_model.get_learned_conditioning([""]) - return d['crossattn'] else: - return sd_model.cond_stage_model([""]) + d = sd_model.cond_stage_model([""]) + + if isinstance(d, dict): + d = d['crossattn'] + + return d def send_model_to_cpu(m): @@ -555,6 +619,25 @@ def send_model_to_trash(m): pass +def instantiate_from_config(config, state_dict=None): + constructor = get_obj_from_str(config["target"]) + + params = {**config.get("params", {})} + + if state_dict and "state_dict" in params and params["state_dict"] is None: + params["state_dict"] = state_dict + + return constructor(**params) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -585,6 +668,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict) sd_model.filename = checkpoint_info.filename + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + del state_dict # clean up cache if limit is reached diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index b38137eb5..fb44c5a8d 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -23,6 +23,8 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") +config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") + def is_using_v_parameterization_for_sd2(state_dict): """ @@ -31,11 +33,11 @@ def is_using_v_parameterization_for_sd2(state_dict): import ldm.modules.diffusionmodules.openaimodel - device = devices.cpu + device = devices.device with sd_disable_initialization.DisableInitialization(): unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( - use_checkpoint=True, + use_checkpoint=False, use_fp16=False, image_size=32, in_channels=4, @@ -56,12 +58,13 
@@ def is_using_v_parameterization_for_sd2(state_dict): with torch.no_grad(): unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} unet.load_state_dict(unet_sd, strict=True) - unet.to(device=device, dtype=torch.float) + unet.to(device=device, dtype=devices.dtype_unet) test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 - out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + with devices.autocast(): + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() return out < -1 @@ -71,11 +74,15 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) + if "model.diffusion_model.x_embedder.proj.weight" in sd: + return config_sd3 + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: if diffusion_model_input.shape[1] == 9: return config_sdxl_inpainting else: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: @@ -99,7 +106,6 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: return config_alt_diffusion_m18 diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index f911fbb68..2fce2777b 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion): is_sd1: bool """True if the model's architecture is SD 1.x""" + + is_sd3: bool + """True if the model's architecture is SD 3""" + + latent_channels: int + """number of layer in latent image representation; will be 16 in SD3 and 4 in other version""" diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 9ea8d6906..4f8c7ee15 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -18,8 +18,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 - width = getattr(batch, 'width', 1024) - height = getattr(batch, 'height', 1024) + width = getattr(batch, 'width', 1024) or 1024 + height = getattr(batch, 'height', 1024) or 1024 is_negative_prompt = getattr(batch, 'is_negative_prompt', False) aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b70679471..9798cefe3 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,8 @@ -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared +from __future__ import annotations + +import functools +import logging +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -12,8 +16,8 
@@ ] all_samplers_map = {x.name: x for x in all_samplers} -samplers = [] -samplers_for_img2img = [] +samplers: list[sd_samplers_common.SamplerData] = [] +samplers_for_img2img: list[sd_samplers_common.SamplerData] = [] samplers_map = {} samplers_hidden = {} @@ -59,4 +63,71 @@ def visible_sampler_names(): return [x.name for x in samplers if x.name not in samplers_hidden] +def visible_samplers(): + return [x for x in samplers if x.name not in samplers_hidden] + + +def get_sampler_from_infotext(d: dict): + return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0] + + +def get_scheduler_from_infotext(d: dict): + return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1] + + +def get_hr_sampler_and_scheduler(d: dict): + hr_sampler = d.get("Hires sampler", "Use same sampler") + sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler + + hr_scheduler = d.get("Hires schedule type", "Use same scheduler") + scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler + + sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler) + + sampler = sampler if sampler != d.get("Sampler") else "Use same sampler" + scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler" + + return sampler, scheduler + + +def get_hr_sampler_from_infotext(d: dict): + return get_hr_sampler_and_scheduler(d)[0] + + +def get_hr_scheduler_from_infotext(d: dict): + return get_hr_sampler_and_scheduler(d)[1] + + +@functools.cache +def get_sampler_and_scheduler(sampler_name, scheduler_name, *, convert_automatic=True): + default_sampler = samplers[0] + found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0]) + + name = sampler_name or default_sampler.name + + for scheduler in sd_schedulers.schedulers: + name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])] + + for name_option in name_options: + if name.endswith(" " + name_option): + found_scheduler = scheduler + name = name[0:-(len(name_option) + 1)] + break + + sampler = all_samplers_map.get(name, default_sampler) + + # revert back to Automatic if it's the default scheduler for the selected sampler + if convert_automatic and sampler.options.get('scheduler', None) == found_scheduler.name: + found_scheduler = sd_schedulers.schedulers[0] + + return sampler.name, found_scheduler.label + + +def fix_p_invalid_sampler_and_scheduler(p): + i_sampler_name, i_scheduler = p.sampler_name, p.scheduler + p.sampler_name, p.scheduler = get_sampler_and_scheduler(p.sampler_name, p.scheduler, convert_automatic=False) + if p.sampler_name != i_sampler_name or i_scheduler != p.scheduler: + logging.warning(f'Sampler Scheduler autocorrection: "{i_sampler_name}" -> "{p.sampler_name}", "{i_scheduler}" -> "{p.scheduler}"') + + set_samplers() diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index e2d3826b2..7594d97ff 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -1,5 +1,5 @@ import torch -from modules import prompt_parser, devices, sd_samplers_common +from modules import prompt_parser, sd_samplers_common from modules.shared import opts, state import modules.shared as shared @@ -183,7 +183,15 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): cond_scale=cond_scale, cond_composition=cond_composition) if self.mask is not None: - denoised = denoised * self.nmask + self.init_latent * self.mask + blended_latent = 
denoised * self.nmask + self.init_latent * self.mask + + if self.p.scripts is not None: + from modules import scripts + mba = scripts.MaskBlendArgs(denoised, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) + self.p.scripts.on_mask_blend(self.p, mba) + blended_latent = mba.blended_latent + + denoised = blended_latent preview = self.sampler.last_latent = denoised sd_samplers_common.store_latent(preview) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 77ae38124..56d1dff47 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -237,7 +237,7 @@ def __init__(self, funcname): self.eta_infotext_field = 'Eta' self.eta_default = 1.0 - self.conditioning_key = shared.sd_model.model.conditioning_key + self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn') self.p = None self.model_wrap_cfg = None diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 887d180d7..348d45f11 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,7 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback @@ -11,32 +11,20 @@ samplers_k_diffusion = [ - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}), + ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), + ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), + ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), - ('DPM++ 
2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), - ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}), - ('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}), - ('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), - ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}), - ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}), - ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}), ] @@ -60,20 +48,20 @@ } k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} -k_diffusion_scheduler = { - 'Automatic': None, - 'karras': k_diffusion.sampling.get_sigmas_karras, - 'exponential': k_diffusion.sampling.get_sigmas_exponential, - 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential -} +k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers} class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): @property def inner_model(self): if self.model_wrap is None: - denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None) + + if denoiser_constructor is not None: + self.model_wrap = denoiser_constructor() + else: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) return self.model_wrap @@ -98,47 +86,52 @@ def get_sigmas(self, p, steps): steps += 1 if discard_next_to_last_sigma else 0 + scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic' + if scheduler_name == 'Automatic': + scheduler_name = self.config.options.get('scheduler', None) + + scheduler = 
sd_schedulers.schedulers_map.get(scheduler_name) + + m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item() + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) + if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) - elif opts.k_sched_type != "Automatic": - m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) - sigmas_kwargs = { - 'sigma_min': sigma_min, - 'sigma_max': sigma_max, - } - - sigmas_func = k_diffusion_scheduler[opts.k_sched_type] - p.extra_generation_params["Schedule type"] = opts.k_sched_type - - if opts.sigma_min != m_sigma_min and opts.sigma_min != 0: + elif scheduler is None or scheduler.function is None: + sigmas = self.model_wrap.get_sigmas(steps) + else: + sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max} + + if scheduler.label != 'Automatic' and not p.is_hr_pass: + p.extra_generation_params["Schedule type"] = scheduler.label + elif scheduler.label != p.extra_generation_params.get("Schedule type"): + p.extra_generation_params["Hires schedule type"] = scheduler.label + + if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min: sigmas_kwargs['sigma_min'] = opts.sigma_min p.extra_generation_params["Schedule min sigma"] = opts.sigma_min - if opts.sigma_max != m_sigma_max and opts.sigma_max != 0: + + if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max: sigmas_kwargs['sigma_max'] = opts.sigma_max p.extra_generation_params["Schedule max sigma"] = opts.sigma_max - default_rho = 1. if opts.k_sched_type == "polyexponential" else 7. - - if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho: + if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho: sigmas_kwargs['rho'] = opts.rho p.extra_generation_params["Schedule rho"] = opts.rho - sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + if scheduler.need_inner_model: + sigmas_kwargs['inner_model'] = self.model_wrap - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': - m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) + if scheduler.label == 'Beta': + p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha + p.extra_generation_params["Beta schedule beta"] = opts.beta_dist_beta + + sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu) if discard_next_to_last_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - return sigmas + return sigmas.cpu() def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): unet_patcher = self.model_wrap.inner_model.forge_objects.unet diff --git a/modules/sd_samplers_timesteps.py 
b/modules/sd_samplers_timesteps.py index 149d67009..9dbb53b61 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -12,6 +12,7 @@ samplers_timesteps = [ ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), + ('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}), ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), ] diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index 930a64af5..180e43899 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -5,13 +5,14 @@ from modules import shared from modules.models.diffusion.uni_pc import uni_pc +from modules.torch_utils import float64 @torch.no_grad() def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x)) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) @@ -39,11 +40,51 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= return x +@torch.no_grad() +def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): + """ Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024). + Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction. + The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0]. + """ + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x)) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) + + model.cond_scale_miltiplier = 1 / 12.5 + model.need_last_noise_uncond = True + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones((x.shape[0])) + s_x = x.new_ones((x.shape[0], 1, 1, 1)) + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + + e_t = model(x, timesteps[index].item() * s_in, **extra_args) + last_noise_uncond = model.last_noise_uncond + + a_t = alphas[index].item() * s_x + a_prev = alphas_prev[index].item() * s_x + sigma_t = sigmas[index].item() * s_x + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x + + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. 
- a_prev - sigma_t ** 2).sqrt() * last_noise_uncond + noise = sigma_t * k_diffusion.sampling.torch.randn_like(x) + x = a_prev.sqrt() * pred_x0 + dir_xt + noise + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + @torch.no_grad() def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x)) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) extra_args = {} if extra_args is None else extra_args diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py new file mode 100644 index 000000000..af873dc97 --- /dev/null +++ b/modules/sd_schedulers.py @@ -0,0 +1,154 @@ +import dataclasses +import torch +import k_diffusion +import numpy as np +from scipy import stats + +from modules import shared + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / sigma + + +k_diffusion.sampling.to_d = to_d + + +@dataclasses.dataclass +class Scheduler: + name: str + label: str + function: any + + default_rho: float = -1 + need_inner_model: bool = False + aliases: list = None + + +def uniform(n, sigma_min, sigma_max, inner_model, device): + return inner_model.get_sigmas(n).to(device) + + +def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): + start = inner_model.sigma_to_t(torch.tensor(sigma_max)) + end = inner_model.sigma_to_t(torch.tensor(sigma_min)) + sigs = [ + inner_model.t_to_sigma(ts) + for ts in torch.linspace(start, end, n + 1)[:-1] + ] + sigs += [0.0] + return torch.FloatTensor(sigs).to(device) + + +def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device): + # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html + def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + + if shared.sd_model.is_sdxl: + sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029] + else: + # Default to SD 1.5 sigmas. 
+ sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] + + if n != len(sigmas): + sigmas = np.append(loglinear_interp(sigmas, n), [0.0]) + else: + sigmas.append(0.0) + + return torch.FloatTensor(sigmas).to(device) + + +def kl_optimal(n, sigma_min, sigma_max, device): + alpha_min = torch.arctan(torch.tensor(sigma_min, device=device)) + alpha_max = torch.arctan(torch.tensor(sigma_max, device=device)) + step_indices = torch.arange(n + 1, device=device) + sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max) + return sigmas + + +def simple_scheduler(n, sigma_min, sigma_max, inner_model, device): + sigs = [] + ss = len(inner_model.sigmas) / n + for x in range(n): + sigs += [float(inner_model.sigmas[-(1 + int(x * ss))])] + sigs += [0.0] + return torch.FloatTensor(sigs).to(device) + + +def normal_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False, floor=False): + start = inner_model.sigma_to_t(torch.tensor(sigma_max)) + end = inner_model.sigma_to_t(torch.tensor(sigma_min)) + + if sgm: + timesteps = torch.linspace(start, end, n + 1)[:-1] + else: + timesteps = torch.linspace(start, end, n) + + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(inner_model.t_to_sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs).to(device) + + +def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device): + sigs = [] + ss = max(len(inner_model.sigmas) // n, 1) + x = 1 + while x < len(inner_model.sigmas): + sigs += [float(inner_model.sigmas[x])] + x += ss + sigs = sigs[::-1] + sigs += [0.0] + return torch.FloatTensor(sigs).to(device) + + +def beta_scheduler(n, sigma_min, sigma_max, inner_model, device): + # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
al, 2024) """ + alpha = shared.opts.beta_dist_alpha + beta = shared.opts.beta_dist_beta + timesteps = 1 - np.linspace(0, 1, n) + timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps] + sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps] + sigmas += [0.0] + return torch.FloatTensor(sigmas).to(device) + + +def turbo_scheduler(n, sigma_min, sigma_max, inner_model, device): + unet = inner_model.inner_model.forge_objects.unet + timesteps = torch.flip(torch.arange(1, n + 1) * float(1000.0 / n) - 1, (0,)).round().long().clip(0, 999) + sigmas = unet.model.model_sampling.sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return sigmas.to(device) + + +schedulers = [ + Scheduler('automatic', 'Automatic', None), + Scheduler('uniform', 'Uniform', uniform, need_inner_model=True), + Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0), + Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential), + Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0), + Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]), + Scheduler('kl_optimal', 'KL Optimal', kl_optimal), + Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas), + Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True), + Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True), + Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True), + Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True), + Scheduler('turbo', 'Turbo', turbo_scheduler, need_inner_model=True), +] + +schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}} diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 3965e223e..c5dda7431 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -8,9 +8,9 @@ class VAEApprox(nn.Module): - def __init__(self): + def __init__(self, latent_channels=4): super(VAEApprox, self).__init__() - self.conv1 = nn.Conv2d(4, 8, (7, 7)) + self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7)) self.conv2 = nn.Conv2d(8, 16, (5, 5)) self.conv3 = nn.Conv2d(16, 32, (3, 3)) self.conv4 = nn.Conv2d(32, 64, (3, 3)) @@ -40,7 +40,13 @@ def download_model(model_path, model_url): def model(): - model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" + if shared.sd_model.is_sd3: + model_name = "vaeapprox-sd3.pt" + elif shared.sd_model.is_sdxl: + model_name = "vaeapprox-sdxl.pt" + else: + model_name = "model.pt" + loaded_model = sd_vae_approx_models.get(model_name) if loaded_model is None: @@ -52,7 +58,7 @@ def model(): model_path = os.path.join(paths.models_path, "VAE-approx", model_name) download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) - loaded_model = VAEApprox() + loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels) loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) @@ -64,7 +70,18 @@ def model(): def cheap_approximation(sample): # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - if shared.sd_model.is_sdxl: + if shared.sd_model.is_sd3: + coeffs = [ + [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650], + [ 0.1848, 0.0762, 
0.0360], [ 0.0944, 0.0360, 0.0889], + [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284], + [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047], + [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039], + [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481], + [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], + [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259], + ] + elif shared.sd_model.is_sdxl: coeffs = [ [ 0.3448, 0.4168, 0.4395], [-0.1953, -0.0290, 0.0250], diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 808eb3624..d06253d2a 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -34,9 +34,9 @@ def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) -def decoder(): +def decoder(latent_channels=4): return nn.Sequential( - Clamp(), conv(4, 64), nn.ReLU(), + Clamp(), conv(latent_channels, 64), nn.ReLU(), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), @@ -44,13 +44,13 @@ def decoder(): ) -def encoder(): +def encoder(latent_channels=4): return nn.Sequential( conv(3, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, 4), + conv(64, latent_channels), ) @@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, decoder_path="taesd_decoder.pth"): + def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.decoder = decoder() + + if latent_channels is None: + latent_channels = 16 if "taesd3" in str(decoder_path) else 4 + + self.decoder = decoder(latent_channels) self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) @@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, encoder_path="taesd_encoder.pth"): + def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.encoder = encoder() + + if latent_channels is None: + latent_channels = 16 if "taesd3" in str(encoder_path) else 4 + + self.encoder = encoder(latent_channels) self.encoder.load_state_dict( torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) @@ -87,7 +95,13 @@ def download_model(model_path, model_url): def decoder_model(): - model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + if shared.sd_model.is_sd3: + model_name = "taesd3_decoder.pth" + elif shared.sd_model.is_sdxl: + model_name = "taesdxl_decoder.pth" + else: + model_name = "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) if loaded_model is None: @@ -106,7 +120,13 @@ def decoder_model(): def encoder_model(): - model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + if shared.sd_model.is_sd3: + model_name = "taesd3_encoder.pth" + elif shared.sd_model.is_sdxl: + model_name 
= "taesdxl_encoder.pth" + else: + model_name = "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) if loaded_model is None: diff --git a/modules/shared.py b/modules/shared.py index ccdca4e70..2a3787f99 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -6,6 +6,10 @@ from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules import util +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon cmd_opts = shared_cmd_options.cmd_opts parser = shared_cmd_options.parser @@ -16,11 +20,11 @@ config_filename = cmd_opts.ui_settings_file hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} -demo = None +demo: gr.Blocks = None -device = None +device: str = None -weight_load_location = None +weight_load_location: str = None xformers_available = False @@ -28,22 +32,22 @@ loaded_hypernetworks = [] -state = None +state: 'shared_state.State' = None -prompt_styles = None +prompt_styles: 'styles.StyleDatabase' = None -interrogator = None +interrogator: 'interrogate.InterrogateModels' = None face_restorers = [] -options_templates = None -opts = None -restricted_opts = None +options_templates: dict = None +opts: options.Options = None +restricted_opts: set[str] = None sd_model: sd_models_types.WebuiSdModel = None -settings_components = None -"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings""" +settings_components: dict = None +"""assigned from ui.py, a mapping on setting names to gradio components responsible for those settings""" tab_names = [] @@ -65,9 +69,9 @@ gradio_theme = gr.themes.Base() -total_tqdm = None +total_tqdm: 'shared_total_tqdm.TotalTQDM' = None -mem_mon = None +mem_mon: 'memmon.MemUsageMonitor' = None options_section = options.options_section OptionInfo = options.OptionInfo @@ -86,3 +90,5 @@ refresh_checkpoints = shared_items.refresh_checkpoints list_samplers = shared_items.list_samplers reload_hypernetworks = shared_items.reload_hypernetworks + +hf_endpoint = os.getenv('HF_ENDPOINT', 'https://huggingface.co') diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py index b6dc31450..b4e3f32bc 100644 --- a/modules/shared_gradio_themes.py +++ b/modules/shared_gradio_themes.py @@ -69,3 +69,44 @@ def reload_gradio_theme(theme_name=None): # append additional values gradio_theme shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity + + +def resolve_var(name: str, gradio_theme=None, history=None): + """ + Attempt to resolve a theme variable name to its value + + Parameters: + name (str): The name of the theme variable + ie "background_fill_primary", "background_fill_primary_dark" + spaces and asterisk (*) prefix is removed from name before lookup + gradio_theme (gradio.themes.ThemeClass): The theme object to resolve the variable from + blank to use the webui default shared.gradio_theme + history (list): A list of previously resolved variables to prevent circular references + for regular use leave blank + Returns: + str: The resolved value + + Error handling: + return either #000000 or #ffffff 
depending on initial name ending with "_dark" + """ + try: + if history is None: + history = [] + if gradio_theme is None: + gradio_theme = shared.gradio_theme + + name = name.strip() + name = name[1:] if name.startswith("*") else name + + if name in history: + raise ValueError(f'Circular references: name "{name}" in {history}') + + if value := getattr(gradio_theme, name, None): + return resolve_var(value, gradio_theme, history + [name]) + else: + return name + + except Exception: + name = history[0] if history else name + errors.report(f'resolve_color({name})', exc_info=True) + return '#000000' if name.endswith("_dark") else '#ffffff' diff --git a/modules/shared_items.py b/modules/shared_items.py index 88f636452..11f10b3f7 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -1,5 +1,8 @@ +import html import sys +from modules import script_callbacks, scripts, ui_components +from modules.options import OptionHTML, OptionInfo from modules.shared_cmd_options import cmd_opts @@ -118,6 +121,45 @@ def ui_reorder_categories(): yield "scripts" +def callbacks_order_settings(): + options = { + "sd_vae_explanation": OptionHTML(""" + For categories below, callbacks added to dropdowns happen before others, in order listed. + """), + + } + + callback_options = {} + + for category, _ in script_callbacks.enumerate_callbacks(): + callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False) + + for method_name in scripts.scripts_txt2img.callback_names: + callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False) + + for method_name in scripts.scripts_img2img.callback_names: + callbacks = callback_options.get("script_" + method_name, []) + + for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False): + if any(x.name == addition.name for x in callbacks): + continue + + callbacks.append(addition) + + callback_options["script_" + method_name] = callbacks + + for category, callbacks in callback_options.items(): + if not callbacks: + continue + + option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]}) + option_info.needs_restart() + option_info.html("
<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>
") + options['prioritized_callbacks_' + category] = option_info + + return options + + class Shared(sys.modules[__name__].__class__): """ this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than diff --git a/modules/shared_options.py b/modules/shared_options.py index 856c07a84..df83a6024 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -19,7 +19,9 @@ "outdir_grids", "outdir_txt2img_grids", "outdir_save", - "outdir_init_images" + "outdir_init_images", + "temp_dir", + "clean_temp_dir_at_start", } categories.register_category("saving", "Saving images") @@ -52,7 +54,7 @@ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"), "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"), - "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + "jpeg_quality": OptionInfo(80, "Quality for saved jpeg and avif images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"), "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), @@ -62,6 +64,7 @@ "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "save_write_log_csv": OptionInfo(True, "Write log.csv when saving images using 'Save' button"), "save_init_img": OptionInfo(False, "Save init images when using img2img"), "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), @@ -101,6 +104,7 @@ "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), + "set_scale_by_when_changing_upscaler": OptionInfo(False, "Automatically set the Scale by factor based on the name of the selected Upscaler."), })) options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), { @@ -126,6 +130,22 @@ "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), })) +options_templates.update(options_section(('profiler', "Profiler", "system"), { + "profiling_explanation": OptionHTML(""" +Those settings allow you to enable torch profiler when generating pictures. +Profiling allows you to see which code uses how much of computer's resources during generation. +Each generation writes its own profile to one file, overwriting previous. +The file can be viewed in Chrome, or on a Perfetto web site. 
+Warning: writing profile can take a lot of time, up to 30 seconds, and the file itelf can be around 500MB in size. +"""), + "profiling_enable": OptionInfo(False, "Enable profiling"), + "profiling_activities": OptionInfo(["CPU"], "Activities", gr.CheckboxGroup, {"choices": ["CPU", "CUDA"]}), + "profiling_record_shapes": OptionInfo(True, "Record shapes"), + "profiling_profile_memory": OptionInfo(True, "Profile memory"), + "profiling_with_stack": OptionInfo(True, "Include python stack"), + "profiling_filename": OptionInfo("trace.json", "Profile filename"), +})) + options_templates.update(options_section(('API', "API", "system"), { "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), @@ -157,6 +177,7 @@ "emphasis": OptionInfo("Original", "Emphasis mode", gr.Radio, lambda: {"choices": [x.name for x in sd_emphasis.options]}, infotext="Emphasis").info("makes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt; " + sd_emphasis.get_options_descriptions()), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), + "sdxl_clip_l_skip": OptionInfo(False, "Clip skip SDXL", gr.Checkbox).info("Enable Clip skip for the secondary clip model in sdxl. Has no effect on SD 1.5 or SD 2.0/2.1."), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), @@ -171,6 +192,10 @@ "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) +options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), { + "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), +})) + options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" VAE is a neural network that transforms a standard RGB @@ -194,7 +219,6 @@ "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}), - "img2img_editor_height": OptionInfo(720, "Height of the image editor", 
gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(), "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(), "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(), "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), @@ -206,14 +230,15 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), - "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}, infotext='NGMS').link("PR", "https://github.com/AUTOMATIC1111/stablediffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "s_min_uncond_all": OptionInfo(False, "Negative Guidance minimum sigma all steps", infotext='NGMS all steps').info("By default, NGMS above skips every other step; this makes it skip all steps"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"), "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; overrides the above if set; WARNING: truncates negative prompt if it's too long; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), - "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, 
but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), })) @@ -224,10 +249,10 @@ "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), - "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), - "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod") + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod"), + "refiner_switch_by_sample_steps": OptionInfo(False, "Switch to refiner by sampling steps instead of model timesteps. 
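To make the Negative Guidance minimum sigma (NGMS) options above more concrete, here is a hypothetical sketch of how a CFG denoising step could honor them; the function and its arguments are illustrative and do not reproduce the webui's actual sampler code.

def cfg_denoise(model, x, sigma, cond, uncond, cfg_scale, step, s_min_uncond=0.0, s_min_uncond_all=False):
    # Skip the negative-prompt (unconditional) pass once sigma drops below the threshold;
    # by default only every other step is skipped, unless s_min_uncond_all is set.
    skip_uncond = s_min_uncond > 0 and sigma < s_min_uncond and (s_min_uncond_all or step % 2 == 0)
    denoised_cond = model(x, sigma, cond)
    if skip_uncond:
        return denoised_cond
    denoised_uncond = model(x, sigma, uncond)
    return denoised_uncond + cfg_scale * (denoised_cond - denoised_uncond)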
Old behavior for refiner.", infotext="Refiner switch by sampling steps") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -257,7 +282,9 @@ "extra_networks_card_description_is_html": OptionInfo(False, "Treat card description as HTML"), "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), - "extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(), + "extra_networks_tree_view_style": OptionInfo("Dirs", "Extra Networks directory view style", gr.Radio, {"choices": ["Tree", "Dirs"]}).needs_reload_ui(), + "extra_networks_tree_view_default_enabled": OptionInfo(True, "Show the Extra Networks directory view by default").needs_reload_ui(), + "extra_networks_tree_view_default_width": OptionInfo(180, "Default width for the Extra Networks directory tree view", gr.Number).needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), @@ -311,6 +338,8 @@ "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), + "enable_reloading_ui_scripts": OptionInfo(False, "Reload UI scripts when using Reload UI option").info("useful for developing: if you make changes to UI scripts code, it is applied when the UI is reloded."), + })) @@ -351,6 +380,7 @@ "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), + "prevent_screen_sleep_during_generation": OptionInfo(True, "Prevent screen sleep during generation"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { @@ -362,22 +392,25 @@ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'), 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'), - 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise 
schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), 'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), - 'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"), + 'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multiplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), - 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models") + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"), + 'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"), + 'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'), + 'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'), })) options_templates.update(options_section(('postprocessing', "Postprocessing", 
"postprocessing"), { 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_disable_in_extras': OptionInfo([], "Disable postprocessing operations in extras tab", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), 'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"), diff --git a/modules/shared_state.py b/modules/shared_state.py index 5da5c7a06..bdbb47147 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -169,5 +169,7 @@ def do_set_current_image(self): @torch.inference_mode() def assign_current_image(self, image): + if shared.opts.live_previews_image_format == 'jpeg' and image.mode in ('RGBA', 'P'): + image = image.convert('RGB') self.current_image = image self.id_live_preview += 1 diff --git a/modules/styles.py b/modules/styles.py index 60bd8a7fb..25f22d3dd 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pathlib import Path from modules import errors import csv @@ -42,7 +43,7 @@ def extract_style_text_from_prompt(style_text, prompt): stripped_style_text = style_text.strip() if "{prompt}" in stripped_style_text: - left, right = stripped_style_text.split("{prompt}", 2) + left, _, right = stripped_style_text.partition("{prompt}") if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] return True, prompt diff --git a/modules/sysinfo.py b/modules/sysinfo.py index f336251e4..e9a83d74e 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -1,15 +1,13 @@ import json import os import sys - +import subprocess import platform import hashlib -import pkg_resources -import psutil import re +from pathlib import Path -import launch -from modules import paths_internal, timer, shared, extensions, errors +from modules import paths_internal, timer, shared_cmd_options, errors, launch_utils checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY" environment_whitelist = { @@ -69,14 +67,46 @@ def check(x): return h.hexdigest() == m.group(1) -def get_dict(): - ram = psutil.virtual_memory() +def get_cpu_info(): + cpu_info = {"model": platform.processor()} + try: + import psutil + cpu_info["count logical"] = psutil.cpu_count(logical=True) + cpu_info["count physical"] = psutil.cpu_count(logical=False) + except Exception as e: + cpu_info["error"] = str(e) + return cpu_info + + +def get_ram_info(): + try: + import psutil + ram = psutil.virtual_memory() + return {x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0} + except Exception as e: + return str(e) + + +def get_packages(): + try: + return 
subprocess.check_output([sys.executable, '-m', 'pip', 'freeze', '--all']).decode("utf8").splitlines() + except Exception as pip_error: + try: + import importlib.metadata + packages = importlib.metadata.distributions() + return sorted([f"{package.metadata['Name']}=={package.version}" for package in packages]) + except Exception as e2: + return {'error pip': pip_error, 'error importlib': str(e2)} + +def get_dict(): + config = get_config() res = { "Platform": platform.platform(), "Python": platform.python_version(), - "Version": launch.git_tag(), - "Commit": launch.commit_hash(), + "Version": launch_utils.git_tag(), + "Commit": launch_utils.commit_hash(), + "Git status": git_status(paths_internal.script_path), "Script path": paths_internal.script_path, "Data path": paths_internal.data_path, "Extensions dir": paths_internal.extensions_dir, @@ -84,20 +114,14 @@ def get_dict(): "Commandline": get_argv(), "Torch env info": get_torch_sysinfo(), "Exceptions": errors.get_exceptions(), - "CPU": { - "model": platform.processor(), - "count logical": psutil.cpu_count(logical=True), - "count physical": psutil.cpu_count(logical=False), - }, - "RAM": { - x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0 - }, - "Extensions": get_extensions(enabled=True), - "Inactive extensions": get_extensions(enabled=False), + "CPU": get_cpu_info(), + "RAM": get_ram_info(), + "Extensions": get_extensions(enabled=True, fallback_disabled_extensions=config.get('disabled_extensions', [])), + "Inactive extensions": get_extensions(enabled=False, fallback_disabled_extensions=config.get('disabled_extensions', [])), "Environment": get_environment(), - "Config": get_config(), + "Config": config, "Startup": timer.startup_record, - "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]), + "Packages": get_packages(), } return res @@ -111,11 +135,11 @@ def get_argv(): res = [] for v in sys.argv: - if shared.cmd_opts.gradio_auth and shared.cmd_opts.gradio_auth == v: + if shared_cmd_options.cmd_opts.gradio_auth and shared_cmd_options.cmd_opts.gradio_auth == v: res.append("") continue - if shared.cmd_opts.api_auth and shared.cmd_opts.api_auth == v: + if shared_cmd_options.cmd_opts.api_auth and shared_cmd_options.cmd_opts.api_auth == v: res.append("") continue @@ -123,6 +147,7 @@ def get_argv(): return res + re_newline = re.compile(r"\r*\n") @@ -136,25 +161,55 @@ def get_torch_sysinfo(): return str(e) -def get_extensions(*, enabled): +def run_git(path, *args): + try: + return subprocess.check_output([launch_utils.git, '-C', path, *args], shell=False, encoding='utf8').strip() + except Exception as e: + return str(e) + + +def git_status(path): + if (Path(path) / '.git').is_dir(): + return run_git(paths_internal.script_path, 'status') + +def get_info_from_repo_path(path: Path): + is_repo = (path / '.git').is_dir() + return { + 'name': path.name, + 'path': str(path), + 'commit': run_git(path, 'rev-parse', 'HEAD') if is_repo else None, + 'branch': run_git(path, 'branch', '--show-current') if is_repo else None, + 'remote': run_git(path, 'remote', 'get-url', 'origin') if is_repo else None, + } + + +def get_extensions(*, enabled, fallback_disabled_extensions=None): try: - def to_json(x: extensions.Extension): - return { - "name": x.name, - "path": x.path, - "version": x.version, - "branch": x.branch, - "remote": x.remote, - } - - return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == 
enabled] + from modules import extensions + if extensions.extensions: + def to_json(x: extensions.Extension): + return { + "name": x.name, + "path": x.path, + "commit": x.commit_hash, + "branch": x.branch, + "remote": x.remote, + } + return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled] + else: + return [get_info_from_repo_path(d) for d in Path(paths_internal.extensions_dir).iterdir() if d.is_dir() and enabled != (str(d.name) in fallback_disabled_extensions)] except Exception as e: return str(e) def get_config(): try: + from modules import shared return shared.opts.data - except Exception as e: - return str(e) + except Exception as _: + try: + with open(shared_cmd_options.cmd_opts.ui_settings_file, 'r') as f: + return json.load(f) + except Exception as e: + return str(e) diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index e223a2e0c..ca858ef4c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -65,7 +65,7 @@ def crop_image(im, settings): rect[3] -= 1 d.rectangle(rect, outline=GREEN) results.append(im_debug) - if settings.destop_view_image: + if settings.desktop_view_image: im_debug.show() return results @@ -341,5 +341,5 @@ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, en self.entropy_points_weight = entropy_points_weight self.face_points_weight = face_points_weight self.annotate_image = annotate_image - self.destop_view_image = False + self.desktop_view_image = False self.dnn_model_path = dnn_model_path diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7ee050615..71c032df7 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -2,7 +2,6 @@ import numpy as np import PIL import torch -from PIL import Image from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms from collections import defaultdict @@ -10,7 +9,7 @@ import random import tqdm -from modules import devices, shared +from modules import devices, shared, images import re from ldm.modules.distributions.distributions import DiagonalGaussianDistribution @@ -61,7 +60,7 @@ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_to if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path) + image = images.read(path) #Currently does not work for single color transparency #We would need to read image.info['transparency'] for that if use_weight and 'A' in image.getbands(): diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 81cff7bf1..eac0f9760 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -1,12 +1,16 @@ import base64 import json +import os.path import warnings +import logging import numpy as np import zlib from PIL import Image, ImageDraw import torch +logger = logging.getLogger(__name__) + class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): @@ -43,7 +47,7 @@ def lcg(m=2**32, a=1664525, c=1013904223, seed=0): def xor_block(block): g = lcg() - randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape) return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F) @@ -114,7 +118,7 @@ def 
extract_image_data_embed(image): outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0) if black_cols[0].shape[0] < 2: - print('No Image data blocks found.') + logger.debug(f'{os.path.basename(getattr(image, "filename", "unknown image file"))}: no embedded information found.') return None data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8) @@ -193,11 +197,11 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t embedded_image = insert_image_data_embed(cap_image, test_embed) - retrived_embed = extract_image_data_embed(embedded_image) + retrieved_embed = extract_image_data_embed(embedded_image) - assert str(retrived_embed) == str(test_embed) + assert str(retrieved_embed) == str(test_embed) - embedded_image2 = insert_image_data_embed(cap_image, retrived_embed) + embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed) assert embedded_image == embedded_image2 diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/saving_settings.py similarity index 96% rename from modules/textual_inversion/logging.py rename to modules/textual_inversion/saving_settings.py index 45823eb11..953051409 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/saving_settings.py @@ -1,64 +1,64 @@ -import datetime -import json -import os - -saved_params_shared = { - "batch_size", - "clip_grad_mode", - "clip_grad_value", - "create_image_every", - "data_root", - "gradient_step", - "initial_step", - "latent_sampling_method", - "learn_rate", - "log_directory", - "model_hash", - "model_name", - "num_of_dataset_images", - "steps", - "template_file", - "training_height", - "training_width", -} -saved_params_ti = { - "embedding_name", - "num_vectors_per_token", - "save_embedding_every", - "save_image_with_stored_embedding", -} -saved_params_hypernet = { - "activation_func", - "add_layer_norm", - "hypernetwork_name", - "layer_structure", - "save_hypernetwork_every", - "use_dropout", - "weight_init", -} -saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet -saved_params_previews = { - "preview_cfg_scale", - "preview_height", - "preview_negative_prompt", - "preview_prompt", - "preview_sampler_index", - "preview_seed", - "preview_steps", - "preview_width", -} - - -def save_settings_to_file(log_directory, all_params): - now = datetime.datetime.now() - params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} - - keys = saved_params_all - if all_params.get('preview_from_txt2img'): - keys = keys | saved_params_previews - - params.update({k: v for k, v in all_params.items() if k in keys}) - - filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' - with open(os.path.join(log_directory, filename), "w") as file: - json.dump(params, file, indent=4) +import datetime +import json +import os + +saved_params_shared = { + "batch_size", + "clip_grad_mode", + "clip_grad_value", + "create_image_every", + "data_root", + "gradient_step", + "initial_step", + "latent_sampling_method", + "learn_rate", + "log_directory", + "model_hash", + "model_name", + "num_of_dataset_images", + "steps", + "template_file", + "training_height", + "training_width", +} +saved_params_ti = { + "embedding_name", + "num_vectors_per_token", + "save_embedding_every", + "save_image_with_stored_embedding", +} +saved_params_hypernet = { + "activation_func", + "add_layer_norm", + "hypernetwork_name", + 
"layer_structure", + "save_hypernetwork_every", + "use_dropout", + "weight_init", +} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = { + "preview_cfg_scale", + "preview_height", + "preview_negative_prompt", + "preview_prompt", + "preview_sampler_index", + "preview_seed", + "preview_steps", + "preview_width", +} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6d815c0b3..dc7833e93 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -17,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay -from modules.textual_inversion.logging import save_settings_to_file +from modules.textual_inversion.saving_settings import save_settings_to_file TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) @@ -172,7 +172,7 @@ def load_from_file(self, path, filename): if data: name = data.get('name', name) else: - # if data is None, means this is not an embeding, just a preview image + # if data is None, means this is not an embedding, just a preview image return elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") @@ -181,12 +181,16 @@ def load_from_file(self, path, filename): else: return - embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) + if data is not None: + embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) - if self.expected_shape == -1 or self.expected_shape == embedding.shape: - self.register_embedding(embedding, shared.sd_model) + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding else: - self.skipped_embeddings[name] = embedding + print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.") + def load_from_dir(self, embdir): if not os.path.isdir(embdir.path): diff --git a/modules/torch_utils.py b/modules/torch_utils.py index e5b52393e..5ea3da094 100644 --- a/modules/torch_utils.py +++ b/modules/torch_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch.nn +import torch def get_param(model) -> torch.nn.Parameter: @@ -15,3 +16,10 @@ def get_param(model) -> torch.nn.Parameter: return param raise ValueError(f"No parameters found in model {model!r}") + + +def float64(t: torch.Tensor): + """return torch.float64 if device is not mps or xpu, else return torch.float32""" + if t.device.type in ['mps', 'xpu']: + return torch.float32 + return torch.float64 diff --git a/modules/txt2img.py b/modules/txt2img.py index 04d62a0ac..9e3d7378e 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -12,7 +12,7 @@ from modules_forge import main_thread -def 
txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): +def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): override_settings = create_override_settings_dict(override_settings_texts) if force_enable_hr: @@ -25,10 +25,8 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne prompt=prompt, styles=prompt_styles, negative_prompt=negative_prompt, - sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, - steps=steps, cfg_scale=cfg_scale, width=width, height=height, @@ -41,6 +39,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne hr_resize_y=hr_resize_y, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name, + hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, diff --git a/modules/ui.py b/modules/ui.py index 300930aa7..b7200fb87 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -8,11 +8,11 @@ import gradio as gr import gradio.utils -import numpy as np +from gradio.components.image_editor import Brush from PIL import Image, PngImagePlugin # noqa: F401 -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call, wrap_gradio_call_no_job # noqa: F401 -from modules import gradio_extensons # noqa: F401 +from modules import gradio_extensions, sd_schedulers # noqa: F401 from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path @@ -29,18 +29,22 @@ from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.infotext_utils import image_from_url_text, PasteField +from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head + create_setting_component = ui_settings.create_setting_component warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) -warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning) +warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else 
"ignore", category=gradio_extensions.GradioDeprecationWarning) # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') +mimetypes.add_type('application/javascript', '.mjs') # Likewise, add explicit content-type header for certain missing image types mimetypes.add_type('image/webp', '.webp') +mimetypes.add_type('image/avif', '.avif') if not cmd_opts.share and not cmd_opts.listen: # fix gradio phoning home @@ -99,8 +103,8 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz def resize_from_to_html(width, height, scale_by): - target_width = int(width * scale_by) - target_height = int(height * scale_by) + target_width = int(float(width) * scale_by) + target_height = int(float(height) * scale_by) if not target_width or not target_height: return "no image selected" @@ -109,10 +113,11 @@ def resize_from_to_html(width, height, scale_by): def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles): - if mode in {0, 1, 3, 4}: + mode = int(mode) + if mode in (0, 1, 3, 4): return [interrogation_function(ii_singles[mode]), None] elif mode == 2: - return [interrogation_function(ii_singles[mode]["image"]), None] + return [interrogation_function(ii_singles[mode]), None] elif mode == 5: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" images = shared.listfiles(ii_input_dir) @@ -235,19 +240,6 @@ def create_output_panel(tabname, outdir, toprow=None): return ui_common.create_output_panel(tabname, outdir, toprow) -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) - - return steps, sampler_name - - def ordered_ui_categories(): user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)} @@ -275,13 +267,17 @@ def create_ui(): parameters_copypaste.reset() + settings = ui_settings.UiSettings() + settings.register_settings() + scripts.scripts_current = scripts.scripts_txt2img scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - with gr.Blocks(analytics_enabled=False) as txt2img_interface: + with gr.Blocks(analytics_enabled=False, head=canvas_head) as txt2img_interface: toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box) - dummy_component = gr.Label(visible=False) + dummy_component = gr.Textbox(visible=False) + dummy_component_number = gr.Number(visible=False) extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs", elem_classes=["extra-networks"]) extra_tabs.__enter__() @@ -298,9 +294,6 @@ def create_ui(): if category == "prompt": toprow.create_inline_toprow_prompts() - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") - elif category == "dimensions": with FormRow(): with 
gr.Column(elem_id="txt2img_column_size", scale=4): @@ -327,7 +320,7 @@ def create_ui(): with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"): with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr: with enable_hr.extra(): - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution") with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) @@ -341,10 +334,11 @@ def create_ui(): with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") + hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with gr.Column(scale=80): @@ -399,8 +393,6 @@ def create_ui(): toprow.prompt, toprow.negative_prompt, toprow.ui_styles.dropdown, - steps, - sampler_name, batch_count, batch_size, cfg_scale, @@ -415,6 +407,7 @@ def create_ui(): hr_resize_y, hr_checkpoint_name, hr_sampler_name, + hr_scheduler, hr_prompt, hr_negative_prompt, override_settings, @@ -429,7 +422,7 @@ def create_ui(): txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", + js=f"(...args) => {{ return submit(args.slice(0, {len(txt2img_inputs)})); }}", inputs=txt2img_inputs, outputs=txt2img_outputs, show_progress=False, @@ -438,10 +431,11 @@ def create_ui(): toprow.prompt.submit(**txt2img_args) toprow.submit.click(**txt2img_args) + txt2img_upscale_inputs = txt2img_inputs[0:1] + [output_panel.gallery, dummy_component_number, output_panel.generation_info] + txt2img_inputs[1:] output_panel.button_upscale.click( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), - _js="submit_txt2img_upscale", - inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:], + js=f"(...args) => {{ return submit_txt2img_upscale(args.slice(0, {len(txt2img_upscale_inputs)})); }}", + inputs=txt2img_upscale_inputs, outputs=txt2img_outputs, show_progress=False, ) @@ -464,8 +458,6 @@ def create_ui(): txt2img_paste_fields = [ PasteField(toprow.prompt, "Prompt", api="prompt"), PasteField(toprow.negative_prompt, "Negative prompt", 
api="negative_prompt"), - PasteField(steps, "Steps", api="steps"), - PasteField(sampler_name, "Sampler", api="sampler_name"), PasteField(cfg_scale, "CFG scale", api="cfg_scale"), PasteField(width, "Size-1", api="width"), PasteField(height, "Size-2", api="height"), @@ -479,8 +471,9 @@ def create_ui(): PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), - PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), - PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), + PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"), + PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"), + PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()), PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), @@ -491,11 +484,13 @@ def create_ui(): paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None, )) + steps = scripts.scripts_txt2img.script('Sampler').steps + txt2img_preview_params = [ toprow.prompt, toprow.negative_prompt, steps, - sampler_name, + scripts.scripts_txt2img.script('Sampler').sampler_name, cfg_scale, scripts.scripts_txt2img.script('Seed').seed, width, @@ -515,7 +510,7 @@ def create_ui(): scripts.scripts_current = scripts.scripts_img2img scripts.scripts_img2img.initialize_scripts(is_img2img=True) - with gr.Blocks(analytics_enabled=False) as img2img_interface: + with gr.Blocks(analytics_enabled=False, head=canvas_head) as img2img_interface: toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box) extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs", elem_classes=["extra-networks"]) @@ -532,9 +527,7 @@ def create_ui(): def add_copy_image_controls(tab_name, elem): with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): - gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") - - for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + for title, name in zip(['to img2img', 'to sketch', 'to inpaint', 'to inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): if name == tab_name: gr.Button(title, interactive=False) copy_image_destinations[name] = elem @@ -554,48 +547,45 @@ def add_copy_image_controls(tab_name, elem): img2img_selected_tab = gr.Number(value=0, visible=False) with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + init_img = 
ForgeCanvas(elem_id="img2img_image", height=512, no_scribbles=True) add_copy_image_controls('img2img', init_img) with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + sketch = ForgeCanvas(elem_id="img2img_sketch", height=512) add_copy_image_controls('sketch', sketch) with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + init_img_with_mask = ForgeCanvas(elem_id="img2maskimg", height=512, contrast_scribbles=True) add_copy_image_controls('inpaint', init_img_with_mask) with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) + inpaint_color_sketch = ForgeCanvas(elem_id="inpaint_sketch", height=512) add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(
- "Process images in a directory on the same machine where the server is running." +
- "Use an empty output directory to save pictures normally instead of writing to the output directory." +
- f"Add inpaint batch mask directory to enable inpaint batch processing."
- f"{hidden}"
- )
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Tabs(elem_id="img2img_batch_source"):
+ img2img_batch_source_type = gr.Textbox(visible=False, value="upload")
+ with gr.TabItem('Upload', id='batch_upload', elem_id="img2img_batch_upload_tab") as tab_batch_upload:
+ img2img_batch_upload = gr.Files(label="Files", interactive=True, elem_id="img2img_batch_upload")
+ with gr.TabItem('From directory', id='batch_from_dir', elem_id="img2img_batch_from_dir_tab") as tab_batch_from_dir:
+ hidden = 'Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(
+ "Process images in a directory on the same machine where the server is running." +
+ "Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ f"Add inpaint batch mask directory to enable inpaint batch processing." f"{hidden}
" + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + tab_batch_upload.select(fn=lambda: "upload", inputs=[], outputs=[img2img_batch_source_type]) + tab_batch_from_dir.select(fn=lambda: "from dir", inputs=[], outputs=[img2img_batch_source_type]) with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", elem_id="img2img_batch_use_png_info") img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") @@ -604,20 +594,14 @@ def update_orig(image, state): for i, tab in enumerate(img2img_tabs): tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - for button, name, elem in copy_image_buttons: button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], + fn=lambda img: img, + inputs=[elem.background], + outputs=[copy_image_destinations[name].background], ) button.click( - fn=lambda: None, + fn=None, _js=f"switch_to_{name.replace(' ', '_')}", inputs=[], outputs=[], @@ -626,16 +610,13 @@ def copy_image(img): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") - elif category == "dimensions": with FormRow(): with gr.Column(elem_id="img2img_column_size", scale=4): selected_scale_tab = gr.Number(value=0, visible=False) - with gr.Tabs(): - with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: + with gr.Tabs(elem_id="img2img_tabs_resize"): + with gr.Tab(label="Resize to", id="to", elem_id="img2img_tab_resize_to") as tab_scale_to: with FormRow(): with gr.Column(elem_id="img2img_column_size", scale=4): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") @@ -644,7 +625,7 @@ def copy_image(img): res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") - with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: + with gr.Tab(label="Resize by", id="by", elem_id="img2img_tab_resize_by") as tab_scale_by: scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") with 
FormRow(): @@ -723,12 +704,6 @@ def copy_image(img): if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - def select_img2img_tab(tab): return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), @@ -741,48 +716,52 @@ def select_img2img_tab(tab): output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) + submit_img2img_inputs = [ + dummy_component, + img2img_selected_tab, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, + init_img.background, + sketch.background, + sketch.foreground, + init_img_with_mask.background, + init_img_with_mask.foreground, + inpaint_color_sketch.background, + inpaint_color_sketch.foreground, + init_img_inpaint, + init_mask_inpaint, + mask_blur, + mask_alpha, + inpainting_fill, + batch_count, + batch_size, + cfg_scale, + image_cfg_scale, + denoising_strength, + selected_scale_tab, + height, + width, + scale_by, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + img2img_batch_inpaint_mask_dir, + override_settings, + img2img_batch_use_png_info, + img2img_batch_png_info_props, + img2img_batch_png_info_dir, + img2img_batch_source_type, + img2img_batch_upload, + ] + custom_inputs + img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - dummy_component, - dummy_component, - toprow.prompt, - toprow.negative_prompt, - toprow.ui_styles.dropdown, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, - steps, - sampler_name, - mask_blur, - mask_alpha, - inpainting_fill, - batch_count, - batch_size, - cfg_scale, - image_cfg_scale, - denoising_strength, - selected_scale_tab, - height, - width, - scale_by, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - img2img_batch_inpaint_mask_dir, - override_settings, - img2img_batch_use_png_info, - img2img_batch_png_info_props, - img2img_batch_png_info_dir, - ] + custom_inputs, + js=f"(...args) => {{ return submit_img2img(args.slice(0, {len(submit_img2img_inputs)})); }}", + inputs=submit_img2img_inputs, outputs=[ output_panel.gallery, output_panel.generation_info, @@ -798,10 +777,10 @@ def select_img2img_tab(tab): dummy_component, img2img_batch_input_dir, img2img_batch_output_dir, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, + init_img.background, + sketch.background, + init_img_with_mask.background, + inpaint_color_sketch.background, init_img_inpaint, ], outputs=[toprow.prompt, dummy_component], @@ -813,9 +792,9 @@ def select_img2img_tab(tab): res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) detect_image_size_btn.click( - fn=lambda w, h, _: (w or gr.update(), h or gr.update()), + fn=lambda w, h: (w or 
gr.update(), h or gr.update()), _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, dummy_component], + inputs=[dummy_component, dummy_component], outputs=[width, height], show_progress=False, ) @@ -843,6 +822,8 @@ def select_img2img_tab(tab): **interrogate_args, ) + steps = scripts.scripts_img2img.script('Sampler').steps + toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) @@ -851,8 +832,6 @@ def select_img2img_tab(tab): img2img_paste_fields = [ (toprow.prompt, "Prompt"), (toprow.negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_name, "Sampler"), (cfg_scale, "CFG scale"), (image_cfg_scale, "Image CFG scale"), (width, "Size-1"), @@ -867,8 +846,8 @@ def select_img2img_tab(tab): (inpaint_full_res_padding, 'Masked area padding'), *scripts.scripts_img2img.infotext_fields ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) + parameters_copypaste.add_paste_fields("img2img", init_img.background, img2img_paste_fields, override_settings) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask.background, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, )) @@ -880,10 +859,10 @@ def select_img2img_tab(tab): scripts.scripts_current = None - with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface: ui_postprocessing.create_ui() - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Blocks(analytics_enabled=False, head=canvas_head) as pnginfo_interface: with ResizeHandleRow(equal_height=False): with gr.Column(variant='panel'): image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") @@ -901,14 +880,14 @@ def select_img2img_tab(tab): )) image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), + fn=wrap_gradio_call_no_job(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger() - with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Blocks(analytics_enabled=False, head=canvas_head) as train_interface: with gr.Row(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.
") @@ -1007,7 +986,7 @@ def get_textual_inversion_template_names(): with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4) + gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4, object_fit="contain") gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") @@ -1122,7 +1101,6 @@ def get_textual_inversion_template_names(): loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) ui_settings_from_file = loadsave.ui_settings.copy() - settings = ui_settings.UiSettings() settings.create_ui(loadsave, dummy_component) interfaces = [ @@ -1144,7 +1122,7 @@ def get_textual_inversion_template_names(): for _interface, label, _ifid in interfaces: shared.tab_names.append(label) - with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion", head=canvas_head) as demo: settings.add_quicksettings() parameters_copypaste.connect_paste_params_buttons() diff --git a/modules/ui_common.py b/modules/ui_common.py index fc5d6e3f4..58c4e27e2 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -3,14 +3,11 @@ import json import html import os -import platform -import sys +from contextlib import nullcontext import gradio as gr -import subprocess as sp -from modules import call_queue, shared, ui_tempdir -from modules.infotext_utils import image_from_url_text +from modules import call_queue, shared, ui_tempdir, util import modules.images from modules.ui_components import ToolButton import modules.infotext_utils as parameters_copypaste @@ -105,21 +102,20 @@ def __init__(self, d=None): logfile_path = os.path.join(shared.opts.outdir_save, "log.csv") # NOTE: ensure csv integrity when fields are added by - # updating headers and padding with delimeters where needed - if os.path.exists(logfile_path): + # updating headers and padding with delimiters where needed + if shared.opts.save_write_log_csv and os.path.exists(logfile_path): update_logfile(logfile_path, fields) - with open(logfile_path, "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(fields) + with (open(logfile_path, "a", encoding="utf8", newline='') if shared.opts.save_write_log_csv else nullcontext()) as file: + if file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(fields) for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - + image = filedata[0] is_grid = image_index < p.index_of_first_image - p.batch_index = image_index-1 parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], []) @@ -133,7 +129,8 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]]) + if file: + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], 
data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: @@ -176,31 +173,7 @@ def open_folder(f, images=None, index=None): except Exception: pass - if not os.path.exists(f): - msg = f'Folder "{f}" does not exist. After you create an image, the folder will be created.' - print(msg) - gr.Info(msg) - return - elif not os.path.isdir(f): - msg = f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""" - print(msg, file=sys.stderr) - gr.Warning(msg) - return - - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) + util.open_folder(f) with gr.Column(elem_id=f"{tabname}_results"): if toprow: @@ -208,7 +181,7 @@ def open_folder(f, images=None, index=None): with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): with gr.Group(elem_id=f"{tabname}_gallery_container"): - res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None, object_fit='contain') + res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None, interactive=False, type="pil", object_fit="contain") with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") @@ -220,8 +193,7 @@ def open_folder(f, images=None, index=None): buttons = { 'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."), 'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."), - 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab."), - 'svd': ToolButton('🎬', elem_id=f'{tabname}_send_to_svd', tooltip="Send image and generation parameters to SVD tab."), + 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.") } if tabname == 'txt2img': @@ -256,7 +228,7 @@ def open_folder(f, images=None, index=None): ) save.click( - fn=call_queue.wrap_gradio_call(save_files), + fn=call_queue.wrap_gradio_call_no_job(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ res.generation_info, @@ -272,7 +244,7 @@ def open_folder(f, images=None, index=None): ) save_zip.click( - fn=call_queue.wrap_gradio_call(save_files), + fn=call_queue.wrap_gradio_call_no_job(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ res.generation_info, diff --git a/modules/ui_components.py b/modules/ui_components.py index 55979f626..6d213ce4f 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,7 +1,12 @@ +from functools import wraps + import gradio as gr +from modules import gradio_extensions # noqa: F401 class FormComponent: + webui_do_not_create_gradio_pyi_thank_you = True + def get_expected_parent(self): return gr.components.Form @@ -9,12 +14,13 @@ def 
get_expected_parent(self): gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent -class ToolButton(FormComponent, gr.Button): +class ToolButton(gr.Button, FormComponent): """Small button with single emoji as text, fits inside gradio forms""" - def __init__(self, *args, **kwargs): - classes = kwargs.pop("elem_classes", []) - super().__init__(*args, elem_classes=["tool", *classes], **kwargs) + @wraps(gr.Button.__init__) + def __init__(self, value="", *args, elem_classes=None, **kwargs): + elem_classes = elem_classes or [] + super().__init__(*args, elem_classes=["tool", *elem_classes], value=value, **kwargs) def get_block_name(self): return "button" @@ -22,7 +28,9 @@ def get_block_name(self): class ResizeHandleRow(gr.Row): """Same as gr.Row but fits inside gradio forms""" + webui_do_not_create_gradio_pyi_thank_you = True + @wraps(gr.Row.__init__) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -32,79 +40,92 @@ def get_block_name(self): return "row" -class FormRow(FormComponent, gr.Row): +class FormRow(gr.Row, FormComponent): """Same as gr.Row but fits inside gradio forms""" def get_block_name(self): return "row" -class FormColumn(FormComponent, gr.Column): +class FormColumn(gr.Column, FormComponent): """Same as gr.Column but fits inside gradio forms""" def get_block_name(self): return "column" -class FormGroup(FormComponent, gr.Group): +class FormGroup(gr.Group, FormComponent): """Same as gr.Group but fits inside gradio forms""" def get_block_name(self): return "group" -class FormHTML(FormComponent, gr.HTML): +class FormHTML(gr.HTML, FormComponent): """Same as gr.HTML but fits inside gradio forms""" def get_block_name(self): return "html" -class FormColorPicker(FormComponent, gr.ColorPicker): +class FormColorPicker(gr.ColorPicker, FormComponent): """Same as gr.ColorPicker but fits inside gradio forms""" def get_block_name(self): return "colorpicker" -class DropdownMulti(FormComponent, gr.Dropdown): +class DropdownMulti(gr.Dropdown, FormComponent): """Same as gr.Dropdown but always multiselect""" + + @wraps(gr.Dropdown.__init__) def __init__(self, **kwargs): - super().__init__(multiselect=True, **kwargs) + kwargs['multiselect'] = True + super().__init__(**kwargs) def get_block_name(self): return "dropdown" -class DropdownEditable(FormComponent, gr.Dropdown): +class DropdownEditable(gr.Dropdown, FormComponent): """Same as gr.Dropdown but allows editing value""" + + @wraps(gr.Dropdown.__init__) def __init__(self, **kwargs): - super().__init__(allow_custom_value=True, **kwargs) + kwargs['allow_custom_value'] = True + super().__init__(**kwargs) def get_block_name(self): return "dropdown" -class InputAccordion(gr.Checkbox): +class InputAccordionImpl(gr.Checkbox): """A gr.Accordion that can be used as an input - returns True if open, False if closed. - Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox. + Actually just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox. 
""" + webui_do_not_create_gradio_pyi_thank_you = True + global_index = 0 - def __init__(self, value, **kwargs): + @wraps(gr.Checkbox.__init__) + def __init__(self, value=None, setup=False, **kwargs): + if not setup: + super().__init__(value=value, **kwargs) + return + self.accordion_id = kwargs.get('elem_id') if self.accordion_id is None: - self.accordion_id = f"input-accordion-{InputAccordion.global_index}" - InputAccordion.global_index += 1 + self.accordion_id = f"input-accordion-{InputAccordionImpl.global_index}" + InputAccordionImpl.global_index += 1 kwargs_checkbox = { **kwargs, "elem_id": f"{self.accordion_id}-checkbox", "visible": False, } - super().__init__(value, **kwargs_checkbox) + super().__init__(value=value, **kwargs_checkbox) self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self]) @@ -115,6 +136,7 @@ def __init__(self, value, **kwargs): "elem_classes": ['input-accordion'], "open": value, } + self.accordion = gr.Accordion(**kwargs_accordion) def extra(self): @@ -143,3 +165,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_block_name(self): return "checkbox" + +def InputAccordion(value=None, **kwargs): + return InputAccordionImpl(value=value, setup=True, **kwargs) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index a24ea32ef..bbf5c113f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -58,8 +58,9 @@ def apply_and_restart(disable_list, update_list, disable_all): def save_config_state(name): current_config_state = config_states.get_config() - if not name: - name = "Config" + + name = os.path.basename(name or "Config") + current_config_state["name"] = name timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S') filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") @@ -380,7 +381,7 @@ def install_extension_from_url(dirname, url, branch_name=None): except OSError as err: if err.errno == errno.EXDEV: # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems - # Since we can't use a rename, do the slower but more versitile shutil.move() + # Since we can't use a rename, do the slower but more versatile shutil.move() shutil.move(tmpdir, target_dir) else: # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled. 
@@ -395,15 +396,15 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url, hide_tags, sort_column, filter_text): +def install_extension_from_index(url, selected_tags, showing_type, filtering_type, sort_column, filter_text): ext_table, message = install_extension_from_url(None, url) - code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) + code, _ = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text) return code, ext_table, message, '' -def refresh_available_extensions(url, hide_tags, sort_column): +def refresh_available_extensions(url, selected_tags, showing_type, filtering_type, sort_column): global available_extensions import urllib.request @@ -412,19 +413,19 @@ def refresh_available_extensions(url, hide_tags, sort_column): available_extensions = json.loads(text) - code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) + code, tags = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column) return url, code, gr.CheckboxGroup.update(choices=tags), '', '' -def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text): - code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) +def refresh_available_extensions_for_tags(selected_tags, showing_type, filtering_type, sort_column, filter_text): + code, _ = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text) return code, '' -def search_extensions(filter_text, hide_tags, sort_column): - code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) +def search_extensions(filter_text, selected_tags, showing_type, filtering_type, sort_column): + code, _ = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text) return code, '' @@ -449,13 +450,13 @@ def get_date(info: dict, key): return '' -def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): +def refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text=""): extlist = available_extensions["extensions"] installed_extensions = {extension.name for extension in extensions.extensions} installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None} tags = available_extensions.get("tags", {}) - tags_to_hide = set(hide_tags) + selected_tags = set(selected_tags) hidden = 0 code = f""" @@ -488,9 +489,19 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls extension_tags = extension_tags + ["installed"] if existing else extension_tags - if any(x for x in extension_tags if x in tags_to_hide): - hidden += 1 - continue + if len(selected_tags) > 0: + matched_tags = [x for x in extension_tags if x in selected_tags] + if filtering_type == 'or': + need_hide = len(matched_tags) > 0 + else: + need_hide = len(matched_tags) == len(selected_tags) + + if showing_type == 'show': + need_hide = not need_hide + + if need_hide: + hidden += 1 + continue if filter_text and filter_text.strip(): if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower(): @@ -593,8 
+604,12 @@ def create_ui(): install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) with gr.Row(): - hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") + selected_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Extension tags", choices=["script", "ads", "localization", "installed"], elem_classes=['compact-checkbox-group']) + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index", elem_classes=['compact-checkbox-group']) + + with gr.Row(): + showing_type = gr.Radio(value="hide", label="Showing type", choices=["hide", "show"], elem_classes=['compact-checkbox-group']) + filtering_type = gr.Radio(value="or", label="Filtering type", choices=["or", "and"], elem_classes=['compact-checkbox-group']) with gr.Row(): search_extensions_text = gr.Text(label="Search", container=False) @@ -604,31 +619,43 @@ def create_ui(): refresh_available_extensions_button.click( fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]), - inputs=[available_extensions_index, hide_tags, sort_column], - outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result], + inputs=[available_extensions_index, selected_tags, showing_type, filtering_type, sort_column], + outputs=[available_extensions_index, available_extensions_table, selected_tags, search_extensions_text, install_result], ) install_extension_button.click( - fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text], + fn=modules.ui.wrap_gradio_call_no_job(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), + inputs=[extension_to_install, selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, extensions_table, install_result], ) search_extensions_text.change( - fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]), - inputs=[search_extensions_text, hide_tags, sort_column], + fn=modules.ui.wrap_gradio_call_no_job(search_extensions, extra_outputs=[gr.update()]), + inputs=[search_extensions_text, selected_tags, showing_type, filtering_type, sort_column], outputs=[available_extensions_table, install_result], ) - hide_tags.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags, sort_column, search_extensions_text], + selected_tags.change( + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], + outputs=[available_extensions_table, install_result] + ) + + showing_type.change( + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], + outputs=[available_extensions_table, 
install_result] + ) + + filtering_type.change( + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, install_result] ) sort_column.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags, sort_column, search_extensions_text], + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, install_result] ) @@ -640,7 +667,7 @@ def create_ui(): install_result = gr.HTML(elem_id="extension_install_result") install_button.click( - fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), inputs=[install_dirname, install_url, install_branch], outputs=[install_url, extensions_table, install_result], ) @@ -661,7 +688,7 @@ def create_ui(): config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info]) - dummy_component = gr.Label(visible=False) + dummy_component = gr.State() config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info]) config_states_list.change( diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 34c46ed40..395549bfb 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,6 +1,8 @@ import functools import os.path import urllib.parse +from base64 import b64decode +from io import BytesIO from pathlib import Path from typing import Optional, Union from dataclasses import dataclass @@ -11,6 +13,7 @@ import json import html from fastapi.exceptions import HTTPException +from PIL import Image from modules.infotext_utils import image_from_url_text @@ -108,6 +111,31 @@ def fetch_file(filename: str = ""): return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) +def fetch_cover_images(page: str = "", item: str = "", index: int = 0): + from starlette.responses import Response + + page = next(iter([x for x in extra_pages if x.name == page]), None) + if page is None: + raise HTTPException(status_code=404, detail="File not found") + + metadata = page.metadata.get(item) + if metadata is None: + raise HTTPException(status_code=404, detail="File not found") + + cover_images = json.loads(metadata.get('ssmd_cover_images', {})) + image = cover_images[index] if index < len(cover_images) else None + if not image: + raise HTTPException(status_code=404, detail="File not found") + + try: + image = Image.open(BytesIO(b64decode(image))) + buffer = BytesIO() + image.save(buffer, format=image.format) + return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype()) + except Exception as err: + raise ValueError(f"File cannot be fetched: {item}. 
Failed to load cover image.") from err + + def get_metadata(page: str = "", item: str = ""): from starlette.responses import JSONResponse @@ -119,6 +147,8 @@ def get_metadata(page: str = "", item: str = ""): if metadata is None: return JSONResponse({}) + metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'} # those are cover images, and they are too big to display in UI as text + return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)}) @@ -142,6 +172,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) + app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"]) app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) @@ -151,6 +182,7 @@ def quote_js(s): s = s.replace('"', '\\"') return f'"{s}"' + class ExtraNetworksPage: def __init__(self, title): self.title = title @@ -164,6 +196,8 @@ def __init__(self, title): self.lister = util.MassFileLister() # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") + self.pane_content_tree_tpl = shared.html("extra-networks-pane-tree.html") + self.pane_content_dirs_tpl = shared.html("extra-networks-pane-dirs.html") self.card_tpl = shared.html("extra-networks-card.html") self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") @@ -243,14 +277,12 @@ def create_item_html( btn_metadata = self.btn_metadata_tpl.format( **{ "extra_networks_tabname": self.extra_networks_tabname, - "name": html.escape(item["name"]), } ) btn_edit_item = self.btn_edit_item_tpl.format( **{ "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, - "name": html.escape(item["name"]), } ) @@ -476,6 +508,47 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional return f"
{res}
" + def create_dirs_view_html(self, tabname: str) -> str: + """Generates HTML for displaying folders.""" + + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): + for dirname in sorted(dirs, key=shared.natural_sort_key): + x = os.path.join(root, dirname) + + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):] + + if shared.opts.extra_networks_dir_button_function: + if not subdir.startswith(os.path.sep): + subdir = os.path.sep + subdir + else: + while subdir.startswith(os.path.sep): + subdir = subdir[1:] + + is_empty = len(os.listdir(x)) == 0 + if not is_empty and not subdir.endswith(os.path.sep): + subdir = subdir + os.path.sep + + if (os.path.sep + "." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories: + continue + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + subdirs_html = "".join([f""" + + """ for subdir in subdirs]) + + return subdirs_html + def create_card_view_html(self, tabname: str, *, none_message) -> str: """Generates HTML for the network Card View section for a tab. @@ -489,15 +562,15 @@ def create_card_view_html(self, tabname: str, *, none_message) -> str: Returns: HTML formatted string. """ - res = "" + res = [] for item in self.items.values(): - res += self.create_item_html(tabname, item, self.card_tpl) + res.append(self.create_item_html(tabname, item, self.card_tpl)) - if res == "": + if not res: dirs = "".join([f"
{x}
  • " for x in self.allowed_directories_for_previews()]) - res = none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs) + res = [none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs)] - return res + return "".join(res) def create_html(self, tabname, *, empty=False): """Generates an HTML string for the current pane. @@ -526,28 +599,28 @@ def create_html(self, tabname, *, empty=False): if "user_metadata" not in item: self.read_user_metadata(item) - data_sortdir = shared.opts.extra_networks_card_order - data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() - data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" - tree_view_btn_extra_class = "" - tree_view_div_extra_class = "hidden" - if shared.opts.extra_networks_tree_view_default_enabled: - tree_view_btn_extra_class = "extra-network-control--enabled" - tree_view_div_extra_class = "" + show_tree = shared.opts.extra_networks_tree_view_default_enabled - return self.pane_tpl.format( - **{ - "tabname": tabname, - "extra_networks_tabname": self.extra_networks_tabname, - "data_sortmode": data_sortmode, - "data_sortkey": data_sortkey, - "data_sortdir": data_sortdir, - "tree_view_btn_extra_class": tree_view_btn_extra_class, - "tree_view_div_extra_class": tree_view_div_extra_class, - "tree_html": self.create_tree_view_html(tabname), - "items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None), - } - ) + page_params = { + "tabname": tabname, + "extra_networks_tabname": self.extra_networks_tabname, + "data_sortdir": shared.opts.extra_networks_card_order, + "sort_path_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Path' else '', + "sort_name_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Name' else '', + "sort_date_created_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Created' else '', + "sort_date_modified_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Modified' else '', + "tree_view_btn_extra_class": "extra-network-control--enabled" if show_tree else "", + "items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None), + "extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width, + "tree_view_div_default_display_class": "" if show_tree else "extra-network-dirs-hidden", + } + + if shared.opts.extra_networks_tree_view_style == "Tree": + pane_content = self.pane_content_tree_tpl.format(**page_params, tree_html=self.create_tree_view_html(tabname)) + else: + pane_content = self.pane_content_dirs_tpl.format(**page_params, dirs_html=self.create_dirs_view_html(tabname)) + + return self.pane_tpl.format(**page_params, pane_content=pane_content) def create_item(self, name, index=None): raise NotImplementedError() @@ -584,6 +657,17 @@ def find_preview(self, path): return None + def find_embedded_preview(self, path, name, metadata): + """ + Find if embedded preview exists in safetensors metadata and return endpoint for it. 
+ """ + + file = f"{path}.safetensors" + if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0: + return f"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}" + + return None + def find_description(self, path): """ Find and read a description file for a given path (without extension). @@ -609,10 +693,10 @@ def initialize(): def register_default_pages(): from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion - from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks + # from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints register_page(ExtraNetworksPageTextualInversion()) - register_page(ExtraNetworksPageHypernetworks()) + # register_page(ExtraNetworksPageHypernetworks()) register_page(ExtraNetworksPageCheckpoints()) @@ -666,9 +750,11 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" page_elem = gr.HTML(page.create_html(tabname, empty=True), elem_id=elem_id) ui.pages.append(page_elem) + editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() ui.user_metadata_editors.append(editor) + related_tabs.append(tab) ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False) @@ -693,7 +779,7 @@ def refresh(): return ui.pages_contents button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_{page.extra_networks_tabname}_extra_refresh_internal", visible=False) - button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js="function(){ " + f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');" + " }") + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js="function(){ " + f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');" + " }").then(fn=lambda: None, _js='setupAllResizeHandles') def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] @@ -703,7 +789,7 @@ def pages_html(): create_html() return ui.pages_contents - interface.load(fn=pages_html, inputs=[], outputs=ui.pages) + interface.load(fn=pages_html, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js='setupAllResizeHandles') return ui diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 2ca937fd1..3a07db105 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -133,8 +133,10 @@ def write_user_metadata(self, name, metadata): filename = item.get("filename", None) basename, ext = os.path.splitext(filename) - with open(basename + '.json', "w", encoding="utf8") as file: + metadata_path = basename + '.json' + with open(metadata_path, "w", encoding="utf8") as file: json.dump(metadata, file, indent=4, ensure_ascii=False) + self.page.lister.update_file_entry(metadata_path) def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) @@ -185,13 +187,14 @@ def save_preview(self, index, gallery, name): geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - + self.page.lister.update_file_entry(item["local_preview"]) + item['preview'] = self.page.find_preview(item["local_preview"]) return 
self.get_card_html(name), '' def setup_ui(self, gallery): self.button_replace_preview.click( fn=self.save_preview, - _js="function(x, y, z){return [selected_gallery_index(), y, z]}", + _js=f"function(x, y, z){{return [selected_gallery_index_id('{self.tabname + '_gallery_container'}'), y, z]}}", inputs=[self.edit_name_input, gallery, self.edit_name_input], outputs=[self.html_preview, self.html_status] ).then( @@ -200,6 +203,3 @@ def setup_ui(self, gallery): inputs=[self.edit_name_input], outputs=[] ) - - - diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index f5278d22f..ed57c1e98 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -41,6 +41,11 @@ def stylesheet(fn): if os.path.exists(user_css): head += stylesheet(user_css) + from modules.shared_gradio_themes import resolve_var + light = resolve_var('background_fill_primary') + dark = resolve_var('background_fill_primary_dark') + head += f'' + return head @@ -50,7 +55,7 @@ def reload_javascript(): def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) res.body = res.body.replace(b'', f'{css}'.encode("utf8")) res.init_headers() return res diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 2555cdb6c..0cc1ab82a 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -104,6 +104,8 @@ def check_dropdown(val): apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) if type(x) == InputAccordion: + if hasattr(x, 'custom_script_source'): + x.accordion.custom_script_source = x.custom_script_source if x.accordion.visible: apply_field(x.accordion, 'visible') apply_field(x, 'value') diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 7261c2df8..7a33ca8f0 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -2,17 +2,18 @@ from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow import modules.infotext_utils as parameters_copypaste from modules.ui_components import ResizeHandleRow +from modules_forge.forge_canvas.canvas import ForgeCanvas def create_ui(): - dummy_component = gr.Label(visible=False) - tab_index = gr.Number(value=0, visible=False) + dummy_component = gr.Textbox(visible=False) + tab_index = gr.State(value=0) with ResizeHandleRow(equal_height=False, variant='compact'): with gr.Column(variant='compact'): with gr.Tabs(elem_id="mode_extras"): with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single: - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + extras_image = ForgeCanvas(elem_id="extras_image", height=512, no_scribbles=True).background with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch: image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch") @@ -35,19 +36,21 @@ def create_ui(): tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + submit_click_inputs = [ + dummy_component, + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ] + submit.click( fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']), - 
_js="submit_extras", - inputs=[ - dummy_component, - tab_index, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - *script_inputs - ], + js=f"(...args) => {{ return submit_extras(args.slice(0, {len(submit_click_inputs)})); }}", + inputs=submit_click_inputs, outputs=[ output_panel.gallery, output_panel.generation_info, diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index d67e3f17e..f71b40c41 100644 --- a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -67,7 +67,7 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): with gr.Row(): self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.") ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles") - self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") + self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selection dropdown in main UI to the prompt.") self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.") with gr.Row(): diff --git a/modules/ui_settings.py b/modules/ui_settings.py index f2576dc56..e750d3714 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,7 +1,8 @@ import gradio as gr -from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer -from modules.call_queue import wrap_gradio_call +from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items +from modules.call_queue import wrap_gradio_call_no_job +from modules.options import options_section from modules.shared import opts from modules.ui_components import FormRow from modules.ui_gradio_extensions import reload_javascript @@ -98,6 +99,9 @@ def run_settings_single(self, value, key): return get_value_for_setting(key), opts.dumpjson() + def register_settings(self): + script_callbacks.ui_settings_callback() + def create_ui(self, loadsave, dummy_component): self.components = [] self.component_dict = {} @@ -105,7 +109,11 @@ def create_ui(self, loadsave, dummy_component): shared.settings_components = self.component_dict - script_callbacks.ui_settings_callback() + # we add this as late as possible so that scripts have already registered their callbacks + opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), { + **shared_items.callbacks_order_settings(), + })) + opts.reorder() with gr.Blocks(analytics_enabled=False) as settings_interface: @@ -287,7 +295,7 @@ def add_quicksettings(self): def add_functionality(self, demo): self.submit.click( - fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), + fn=wrap_gradio_call_no_job(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), inputs=self.components, 
outputs=[self.text_settings, self.result], ) @@ -303,30 +311,20 @@ def add_functionality(self, demo): methods = [component.change] for method in methods: - handler = method( + method( fn=lambda value, k=k: self.run_settings_single(value, key=k), inputs=[component], outputs=[component, self.text_settings], show_progress=False, ) - script_callbacks.setting_updated_event_subscriber_chain( - handler=handler, - component=component, - setting_name=k, - ) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) - handler = button_set_checkpoint.click( + button_set_checkpoint.click( fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], ) - script_callbacks.setting_updated_event_subscriber_chain( - handler=handler, - component=self.component_dict['sd_model_checkpoint'], - setting_name="sd_model_checkpoint" - ) component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict] diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index ecd6bdec3..af9601f3a 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -4,6 +4,7 @@ from pathlib import Path import gradio.components +import gradio as gr from PIL import PngImagePlugin @@ -13,25 +14,35 @@ Savedfile = namedtuple("Savedfile", ["name"]) -def register_tmp_file(gradio, filename): - if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 - gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} +def register_tmp_file(gradio_app, filename): + if hasattr(gradio_app, 'temp_file_sets'): # gradio 3.15 + if hasattr(gr.utils, 'abspath'): # gradio 4.19 + filename = gr.utils.abspath(filename) + else: + filename = os.path.abspath(filename) - if hasattr(gradio, 'temp_dirs'): # gradio 3.9 - gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} + gradio_app.temp_file_sets[0] = gradio_app.temp_file_sets[0] | {filename} + if hasattr(gradio_app, 'temp_dirs'): # gradio 3.9 + gradio_app.temp_dirs = gradio_app.temp_dirs | {os.path.abspath(os.path.dirname(filename))} -def check_tmp_file(gradio, filename): - if hasattr(gradio, 'temp_file_sets'): - return any(filename in fileset for fileset in gradio.temp_file_sets) - if hasattr(gradio, 'temp_dirs'): - return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) +def check_tmp_file(gradio_app, filename): + if hasattr(gradio_app, 'temp_file_sets'): + if hasattr(gr.utils, 'abspath'): # gradio 4.19 + filename = gr.utils.abspath(filename) + else: + filename = os.path.abspath(filename) + + return any(filename in fileset for fileset in gradio_app.temp_file_sets) + + if hasattr(gradio_app, 'temp_dirs'): + return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio_app.temp_dirs) return False -def save_pil_to_file(self, pil_image, dir=None, format="png"): +def save_pil_to_file(pil_image, cache_dir=None, format="png"): already_saved_as = getattr(pil_image, 'already_saved_as', None) if already_saved_as and os.path.isfile(already_saved_as): register_tmp_file(shared.demo, already_saved_as) @@ -39,9 +50,10 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): register_tmp_file(shared.demo, filename_with_mtime) return filename_with_mtime - if 
shared.opts.temp_dir != "": + if shared.opts.temp_dir: dir = shared.opts.temp_dir else: + dir = cache_dir os.makedirs(dir, exist_ok=True) use_metadata = False @@ -56,9 +68,96 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): return file_obj.name +async def async_move_files_to_cache(data, block, postprocess=False, check_in_upload_folder=False, keep_in_cache=False): + """Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file. + Also handles the case where the file is on an external Gradio app (/proxy=...). + + Runs after .postprocess() and before .preprocess(). + + Copied from gradio's processing_utils.py + + Args: + data: The input or output data for a component. Can be a dictionary or a dataclass + block: The component whose data is being processed + postprocess: Whether its running from postprocessing + check_in_upload_folder: If True, instead of moving the file to cache, checks if the file is in already in cache (exception if not). + keep_in_cache: If True, the file will not be deleted from cache when the server is shut down. + """ + + from gradio import FileData + from gradio.data_classes import GradioRootModel + from gradio.data_classes import GradioModel + from gradio_client import utils as client_utils + from gradio.utils import get_upload_folder, is_in_or_equal, is_static_file + + async def _move_to_cache(d: dict): + payload = FileData(**d) + + # EDITED + payload.path = payload.path.rsplit('?', 1)[0] + + # If the gradio app developer is returning a URL from + # postprocess, it means the component can display a URL + # without it being served from the gradio server + # This makes it so that the URL is not downloaded and speeds up event processing + if payload.url and postprocess and client_utils.is_http_url_like(payload.url): + payload.path = payload.url + elif is_static_file(payload): + pass + elif not block.proxy_url: + # EDITED + if check_tmp_file(shared.demo, payload.path): + temp_file_path = payload.path + else: + # If the file is on a remote server, do not move it to cache. + if check_in_upload_folder and not client_utils.is_http_url_like( + payload.path + ): + path = os.path.abspath(payload.path) + if not is_in_or_equal(path, get_upload_folder()): + raise ValueError( + f"File {path} is not in the upload folder and cannot be accessed." + ) + if not payload.is_stream: + temp_file_path = await block.async_move_resource_to_block_cache( + payload.path + ) + if temp_file_path is None: + raise ValueError("Did not determine a file path for the resource.") + payload.path = temp_file_path + if keep_in_cache: + block.keep_in_cache.add(payload.path) + + url_prefix = "/stream/" if payload.is_stream else "/file=" + if block.proxy_url: + proxy_url = block.proxy_url.rstrip("/") + url = f"/proxy={proxy_url}{url_prefix}{payload.path}" + elif client_utils.is_http_url_like(payload.path) or payload.path.startswith( + f"{url_prefix}" + ): + url = payload.path + else: + url = f"{url_prefix}{payload.path}" + payload.url = url + + return payload.model_dump() + + if isinstance(data, (GradioRootModel, GradioModel)): + data = data.model_dump() + + return await client_utils.async_traverse( + data, _move_to_cache, client_utils.is_file_obj + ) + + def install_ui_tempdir_override(): - """override save to file function so that it also writes PNG info""" - gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file + """ + override save to file function so that it also writes PNG info. 
+ override gradio4's move_files_to_cache function to prevent it from writing a copy into a temporary directory. + """ + + gradio.processing_utils.save_pil_to_cache = save_pil_to_file + gradio.processing_utils.async_move_files_to_cache = async_move_files_to_cache def on_tmpdir_changed(): diff --git a/modules/upscaler.py b/modules/upscaler.py index 0e38d52fb..507881fed 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -7,7 +7,6 @@ import modules.shared from modules import modelloader, shared - LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) @@ -21,7 +20,7 @@ class Upscaler: filter = None model = None user_path = None - scalers: [] + scalers: list tile = True def __init__(self, create_dirs=False): @@ -57,8 +56,11 @@ def upscale(self, img: PIL.Image, scale, selected_model: str = None): dest_w = int((img.width * scale) // 8 * 8) dest_h = int((img.height * scale) // 8 * 8) - for _ in range(3): - if img.width >= dest_w and img.height >= dest_h: + for i in range(3): + if img.width >= dest_w and img.height >= dest_h and (i > 0 or scale != 1): + break + + if shared.state.interrupted: break shape = (img.width, img.height) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index b5e5a80ca..5ecbbed96 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -69,10 +69,10 @@ def upscale_with_model( for y, h, row in grid.tiles: newrow = [] for x, w, tile in row: - logger.debug("Tile (%d, %d) %s...", x, y, tile) + if shared.state.interrupted: + return img output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width - logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) p.update(1) newtiles.append([y * scale_factor, h * scale_factor, newrow]) diff --git a/modules/util.py b/modules/util.py index 8d1aea44f..7911b0db7 100644 --- a/modules/util.py +++ b/modules/util.py @@ -81,6 +81,17 @@ def __init__(self, dirname): self.files = {x[0].lower(): x for x in files} self.files_cased = {x[0]: x for x in files} + def update_entry(self, filename): + """Add a file to the cache""" + file_path = os.path.join(self.dirname, filename) + try: + stat = os.stat(file_path) + entry = (filename, stat.st_mtime, stat.st_ctime) + self.files[filename.lower()] = entry + self.files_cased[filename] = entry + except FileNotFoundError as e: + print(f'MassFileListerCachedDir.add_entry: "{file_path}" {e}') + class MassFileLister: """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.""" @@ -136,3 +147,67 @@ def mctime(self, path): def reset(self): """Clear the cache of all directories.""" self.cached_dirs.clear() + + def update_file_entry(self, path): + """Update the cache for a specific directory.""" + dirname, filename = os.path.split(path) + if cached_dir := self.cached_dirs.get(dirname): + cached_dir.update_entry(filename) + +def topological_sort(dependencies): + """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. 
+ Ignores errors relating to missing dependencies or circular dependencies + """ + + visited = {} + result = [] + + def inner(name): + visited[name] = True + + for dep in dependencies.get(name, []): + if dep in dependencies and dep not in visited: + inner(dep) + + result.append(name) + + for depname in dependencies: + if depname not in visited: + inner(depname) + + return result + + +def open_folder(path): + """Open a folder in the file manager of the respect OS.""" + # import at function level to avoid potential issues + import gradio as gr + import platform + import sys + import subprocess + + if not os.path.exists(path): + msg = f'Folder "{path}" does not exist. after you save an image, the folder will be created.' + print(msg) + gr.Info(msg) + return + elif not os.path.isdir(path): + msg = f""" +WARNING +An open_folder request was made with an path that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {path} +""" + print(msg, file=sys.stderr) + gr.Warning(msg) + return + + path = os.path.normpath(path) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + subprocess.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])]) + else: + subprocess.Popen(["xdg-open", path]) diff --git a/modules_forge/forge_alter_samplers.py b/modules_forge/forge_alter_samplers.py index 4e4822086..8316d322e 100644 --- a/modules_forge/forge_alter_samplers.py +++ b/modules_forge/forge_alter_samplers.py @@ -1,45 +1,22 @@ -import torch from modules import sd_samplers_kdiffusion, sd_samplers_common - from ldm_patched.k_diffusion import sampling as k_diffusion_sampling -from ldm_patched.modules.samplers import calculate_sigmas_scheduler class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler): - def __init__(self, sd_model, sampler_name, scheduler_name): + def __init__(self, sd_model, sampler_name): self.sampler_name = sampler_name - self.scheduler_name = scheduler_name self.unet = sd_model.forge_objects.unet - sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) super().__init__(sampler_function, sd_model, None) - def get_sigmas(self, p, steps): - if self.scheduler_name == 'turbo': - timesteps = torch.flip(torch.arange(1, steps + 1) * float(1000.0 / steps) - 1, (0,)).round().long().clip(0, 999) - sigmas = self.unet.model.model_sampling.sigma(timesteps) - sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) - else: - sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps) - return sigmas.to(self.unet.load_device) - -def build_constructor(sampler_name, scheduler_name): +def build_constructor(sampler_name): def constructor(m): - return AlterSampler(m, sampler_name, scheduler_name) + return AlterSampler(m, sampler_name) return constructor samplers_data_alter = [ - sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}), - sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}), - sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}), - sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}), - 
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}), - sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}), - sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}), - sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}), - sd_samplers_common.SamplerData('DPM++ 2M SGMUniform', build_constructor(sampler_name='dpmpp_2m', scheduler_name='sgm_uniform'), ['dpmpp_2m_sgm_uniform'], {}), - sd_samplers_common.SamplerData('DPM++ 2M SDE SGMUniform', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='sgm_uniform'), ['dpmpp_2m_sde_sgm_uniform'], {}), + sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}), ] diff --git a/modules_forge/forge_canvas/canvas.css b/modules_forge/forge_canvas/canvas.css new file mode 100644 index 000000000..bc32583d5 --- /dev/null +++ b/modules_forge/forge_canvas/canvas.css @@ -0,0 +1,159 @@ +.forge-container { + width: 100%; + height: 512px; + position: relative; + overflow: hidden; +} + +.forge-image-container { + width: 100%; + height: calc(100% - 6px); + position: relative; + overflow: hidden; + background-color: #cccccc; + background-image: linear-gradient(45deg, #eee 25%, transparent 25%, transparent 75%, #eee 75%, #eee), + linear-gradient(45deg, #eee 25%, transparent 25%, transparent 75%, #eee 75%, #eee); + background-size: 20px 20px; + background-position: 0 0, 10px 10px; +} + +.forge-image { + position: absolute; + top: 0; + left: 0; + background-size: contain; + background-repeat: no-repeat; + cursor: grab; + max-width: unset !important; + max-height: unset !important; +} + +.forge-image:active { + cursor: grabbing; +} + +.forge-file-upload { + display: none; +} + +.forge-resize-line { + width: 100%; + height: 6px; + background-image: linear-gradient(to bottom, grey 50%, darkgrey 50%); + background-size: 4px 4px; + background-repeat: repeat; + cursor: ns-resize; + position: absolute; + bottom: 0; + left: 0; +} + +.forge-toolbar { + position: absolute; + top: 0px; + left: 0px; + z-index: 10; + background: rgba(47, 47, 47, 0.8); + padding: 6px 10px; + opacity: 0; + transition: opacity 0.3s ease; +} + +.forge-toolbar .forge-btn { + padding: 2px 6px; + border: none; + background-color: #4a4a4a; + color: white; + font-size: 14px; + cursor: pointer; + transition: background-color 0.3s ease; +} + +.forge-toolbar .forge-btn:hover { + background-color: #5e5e5e; +} + +.forge-toolbar .forge-btn:active { + background-color: #3e3e3e; +} + +.forge-toolbar-box-a { + flex-wrap: wrap; +} + +.forge-toolbar-box-b { + display: flex; + flex-wrap: wrap; + align-items: center; + justify-content: space-between; + gap: 4px; +} + +.forge-color-picker-block { + display: flex; + align-items: center; +} + +.forge-range-row { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; +} + +.forge-toolbar-color { + border: none; + background: none; + padding: 3px; + border-radius: 50%; + width: 20px; + height: 20px; + -webkit-appearance: none; + appearance: none; + cursor: pointer; +} + +.forge-toolbar-color::-webkit-color-swatch-wrapper { + padding: 0; + border-radius: 50%; +} + +.forge-toolbar-color::-webkit-color-swatch { + 
border: none; + border-radius: 50%; + background: none; +} + +.forge-toolbar-label { + color: white !important; + padding: 0 4px; + display: flex; + align-items: center; + margin-bottom: 4px; /* Adjust margin as needed */ +} + +.forge-toolbar-range { +} + +.forge-scribble-indicator { + position: relative; + border-radius: 50%; + border: 1px solid; + pointer-events: none; + display: none; + width: 80px; + height: 80px; +} + +.forge-no-select { + user-select: none; +} + +.forge-upload-hint { + position: absolute; + top: 50%; + left: 50%; + width: 30%; + height: 30%; + transform: translate(-50%, -50%); +} diff --git a/modules_forge/forge_canvas/canvas.html b/modules_forge/forge_canvas/canvas.html new file mode 100644 index 000000000..a9554d190 --- /dev/null +++ b/modules_forge/forge_canvas/canvas.html @@ -0,0 +1,63 @@ +
+<!-- canvas.html markup not recoverable from this capture; of the 63 added lines only the toolbar slider labels "brush width", "brush opacity" and "brush softness" survive -->
    diff --git a/modules_forge/forge_canvas/canvas.min.js b/modules_forge/forge_canvas/canvas.min.js new file mode 100644 index 000000000..601014072 --- /dev/null +++ b/modules_forge/forge_canvas/canvas.min.js @@ -0,0 +1 @@ +const _0x374a8c=_0xe5ae;(function(_0xb56795,_0x5457a6){const _0x5032c4=_0xe5ae,_0x36dd93=_0xb56795();while(!![]){try{const _0x8ddbe3=-parseInt(_0x5032c4(0x28c))/0x1*(-parseInt(_0x5032c4(0x240))/0x2)+-parseInt(_0x5032c4(0x27e))/0x3*(-parseInt(_0x5032c4(0x255))/0x4)+parseInt(_0x5032c4(0x205))/0x5+-parseInt(_0x5032c4(0x25b))/0x6+parseInt(_0x5032c4(0x262))/0x7+-parseInt(_0x5032c4(0x270))/0x8*(-parseInt(_0x5032c4(0x1e8))/0x9)+-parseInt(_0x5032c4(0x273))/0xa;if(_0x8ddbe3===_0x5457a6)break;else _0x36dd93['push'](_0x36dd93['shift']());}catch(_0x43c961){_0x36dd93['push'](_0x36dd93['shift']());}}}(_0x3ec7,0xef695));class GradioTextAreaBind{constructor(_0x1be7d4,_0x49c014){const _0x26ce03=_0xe5ae;this[_0x26ce03(0x1e9)]=document[_0x26ce03(0x1ef)]('#'+_0x1be7d4+'.'+_0x49c014+_0x26ce03(0x287)),this['sync_lock']=![],this['previousValue']='';}['set_value'](_0x2f6819){const _0x3acfca=_0xe5ae;if(this[_0x3acfca(0x22d)])return;this[_0x3acfca(0x22d)]=!![],this[_0x3acfca(0x1e9)]['value']=_0x2f6819,this[_0x3acfca(0x208)]=_0x2f6819;let _0x4b8a67=new Event(_0x3acfca(0x204),{'bubbles':!![]});Object[_0x3acfca(0x27b)](_0x4b8a67,'target',{'value':this[_0x3acfca(0x1e9)]}),this['target'][_0x3acfca(0x20c)](_0x4b8a67),this[_0x3acfca(0x208)]=_0x2f6819,this['sync_lock']=![];}[_0x374a8c(0x274)](_0x28b4a1){setInterval(()=>{const _0x3ec347=_0xe5ae;if(this['target'][_0x3ec347(0x1f6)]!==this['previousValue']){this['previousValue']=this[_0x3ec347(0x1e9)][_0x3ec347(0x1f6)];if(this['sync_lock'])return;this[_0x3ec347(0x22d)]=!![],_0x28b4a1(this['target'][_0x3ec347(0x1f6)]),this['sync_lock']=![];}},0x64);}}class ForgeCanvas{constructor(_0x486b67,_0x2e2919=![],_0xe39b25=![],_0x159b35=![],_0x4a8f00=0x200,_0x35abee=_0x374a8c(0x1f3),_0x4d1374=![],_0x372354=0x4,_0x4d2b07=![],_0x1f66ae=0x64,_0x34dc71=![],_0x110907=0x0,_0x55db66=![]){const _0x54ab35=_0x374a8c;this['gradio_config']=gradio_config,this['uuid']=_0x486b67,this[_0x54ab35(0x241)]=_0xe39b25,this[_0x54ab35(0x1cf)]=_0x159b35,this[_0x54ab35(0x1f5)]=_0x2e2919,this[_0x54ab35(0x1f4)]=_0x4a8f00,this[_0x54ab35(0x28d)]=null,this[_0x54ab35(0x250)]=0x0,this[_0x54ab35(0x24e)]=0x0,this[_0x54ab35(0x1d9)]=0x0,this[_0x54ab35(0x24a)]=0x0,this['imgScale']=0x1,this['dragging']=![],this[_0x54ab35(0x261)]=![],this[_0x54ab35(0x1ec)]=![],this[_0x54ab35(0x1fb)]=![],this[_0x54ab35(0x24c)]=_0x35abee,this[_0x54ab35(0x1e0)]=_0x372354,this[_0x54ab35(0x230)]=_0x1f66ae,this[_0x54ab35(0x21d)]=_0x110907,this[_0x54ab35(0x217)]=_0x4d1374,this['scribbleWidthFixed']=_0x4d2b07,this[_0x54ab35(0x1e2)]=_0x34dc71,this[_0x54ab35(0x269)]=_0x55db66,this[_0x54ab35(0x23c)]=[],this[_0x54ab35(0x246)]=-0x1,this['maximized']=![],this[_0x54ab35(0x28a)]={},this['contrast_pattern']=null,this['mouseInsideContainer']=![],this[_0x54ab35(0x26c)]=document['createElement']('canvas'),this[_0x54ab35(0x213)]=[],this['temp_draw_bg']=null,this[_0x54ab35(0x259)]=new GradioTextAreaBind(this[_0x54ab35(0x1dc)],_0x54ab35(0x1d1)),this['foreground_gradio_bind']=new GradioTextAreaBind(this[_0x54ab35(0x1dc)],_0x54ab35(0x1d5)),this[_0x54ab35(0x1eb)]();}[_0x374a8c(0x1eb)](){const _0x4611cb=_0x374a8c;let _0x44277b=this;const 
_0x540dd1=document[_0x4611cb(0x26d)](_0x4611cb(0x1ca)+_0x44277b['uuid']),_0x6f013c=document[_0x4611cb(0x26d)](_0x4611cb(0x227)+_0x44277b['uuid']),_0x1d67ee=document['getElementById'](_0x4611cb(0x1cd)+_0x44277b[_0x4611cb(0x1dc)]),_0x39d95e=document[_0x4611cb(0x26d)](_0x4611cb(0x282)+_0x44277b['uuid']),_0x3201f4=document[_0x4611cb(0x26d)]('toolbar_'+_0x44277b['uuid']),_0x2ba998=document[_0x4611cb(0x26d)](_0x4611cb(0x263)+_0x44277b[_0x4611cb(0x1dc)]),_0x3c3fac=document[_0x4611cb(0x26d)]('resetButton_'+_0x44277b['uuid']),_0x47f742=document['getElementById'](_0x4611cb(0x21e)+_0x44277b[_0x4611cb(0x1dc)]),_0x718fea=document['getElementById'](_0x4611cb(0x202)+_0x44277b[_0x4611cb(0x1dc)]),_0x31aae5=document[_0x4611cb(0x26d)](_0x4611cb(0x1f7)+_0x44277b[_0x4611cb(0x1dc)]),_0xdd4e04=document[_0x4611cb(0x26d)](_0x4611cb(0x267)+_0x44277b[_0x4611cb(0x1dc)]),_0x3723db=document[_0x4611cb(0x26d)](_0x4611cb(0x237)+_0x44277b[_0x4611cb(0x1dc)]),_0x2fdf34=document[_0x4611cb(0x26d)](_0x4611cb(0x228)+_0x44277b[_0x4611cb(0x1dc)]),_0x323f6d=document[_0x4611cb(0x26d)](_0x4611cb(0x23a)+_0x44277b['uuid']),_0x67434b=document[_0x4611cb(0x26d)](_0x4611cb(0x215)+_0x44277b[_0x4611cb(0x1dc)]),_0x4cb610=document[_0x4611cb(0x26d)](_0x4611cb(0x22b)+_0x44277b[_0x4611cb(0x1dc)]),_0x12201f=document[_0x4611cb(0x26d)](_0x4611cb(0x268)+_0x44277b[_0x4611cb(0x1dc)]),_0x4517ef=document[_0x4611cb(0x26d)](_0x4611cb(0x283)+_0x44277b[_0x4611cb(0x1dc)]),_0xeb2930=document['getElementById']('scribbleWidth_'+_0x44277b[_0x4611cb(0x1dc)]),_0x16a8be=document[_0x4611cb(0x26d)](_0x4611cb(0x234)+_0x44277b['uuid']),_0xd669cd=document[_0x4611cb(0x26d)](_0x4611cb(0x276)+_0x44277b[_0x4611cb(0x1dc)]),_0x4aac00=document[_0x4611cb(0x26d)]('scribbleAlpha_'+_0x44277b[_0x4611cb(0x1dc)]),_0x5e0b2c=document['getElementById'](_0x4611cb(0x1df)+_0x44277b[_0x4611cb(0x1dc)]),_0x102790=document['getElementById'](_0x4611cb(0x258)+_0x44277b[_0x4611cb(0x1dc)]),_0x4cd12a=document[_0x4611cb(0x26d)](_0x4611cb(0x1c9)+_0x44277b[_0x4611cb(0x1dc)]),_0x56a2cd=document[_0x4611cb(0x26d)]('softnessLabel_'+_0x44277b[_0x4611cb(0x1dc)]),_0x4ddd50=document['getElementById'](_0x4611cb(0x222)+_0x44277b[_0x4611cb(0x1dc)]);_0x12201f[_0x4611cb(0x1f6)]=_0x44277b[_0x4611cb(0x24c)],_0xeb2930[_0x4611cb(0x1f6)]=_0x44277b[_0x4611cb(0x1e0)],_0x4aac00[_0x4611cb(0x1f6)]=_0x44277b['scribbleAlpha'],_0x4cd12a['value']=_0x44277b[_0x4611cb(0x21d)];const _0x3f7f50=_0x44277b[_0x4611cb(0x1e0)]*0x14;_0x67434b['style'][_0x4611cb(0x242)]=_0x3f7f50+'px',_0x67434b[_0x4611cb(0x239)][_0x4611cb(0x203)]=_0x3f7f50+'px',_0x39d95e[_0x4611cb(0x239)][_0x4611cb(0x203)]=_0x44277b[_0x4611cb(0x1f4)]+'px',_0x3723db[_0x4611cb(0x242)]=_0x540dd1[_0x4611cb(0x285)],_0x3723db[_0x4611cb(0x203)]=_0x540dd1['clientHeight'];const 
_0x250102=_0x3723db[_0x4611cb(0x24d)]('2d');_0x44277b[_0x4611cb(0x27f)]=_0x3723db;_0x44277b[_0x4611cb(0x241)]&&(_0x3c3fac['style'][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0x31aae5[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0xdd4e04[_0x4611cb(0x239)][_0x4611cb(0x1fe)]='none',_0x12201f['style'][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0x4517ef[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0xd669cd[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0x102790[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0x4ddd50['style'][_0x4611cb(0x1fe)]='none',_0x67434b[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0x3723db[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272));_0x44277b['no_upload']&&(_0x2ba998[_0x4611cb(0x239)]['display']=_0x4611cb(0x272),_0x4cb610[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272));if(_0x44277b['contrast_scribbles']){_0x4517ef[_0x4611cb(0x239)][_0x4611cb(0x1fe)]='none',_0x102790[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272),_0x4ddd50['style'][_0x4611cb(0x1fe)]=_0x4611cb(0x272);const _0x193153=_0x44277b[_0x4611cb(0x26c)],_0x21d8e9=0xa;_0x193153[_0x4611cb(0x242)]=_0x21d8e9*0x2,_0x193153[_0x4611cb(0x203)]=_0x21d8e9*0x2;const _0x2537fc=_0x193153[_0x4611cb(0x24d)]('2d');_0x2537fc[_0x4611cb(0x1db)]=_0x4611cb(0x21c),_0x2537fc['fillRect'](0x0,0x0,_0x21d8e9,_0x21d8e9),_0x2537fc[_0x4611cb(0x210)](_0x21d8e9,_0x21d8e9,_0x21d8e9,_0x21d8e9),_0x2537fc[_0x4611cb(0x1db)]='#000000',_0x2537fc[_0x4611cb(0x210)](_0x21d8e9,0x0,_0x21d8e9,_0x21d8e9),_0x2537fc[_0x4611cb(0x210)](0x0,_0x21d8e9,_0x21d8e9,_0x21d8e9),_0x44277b[_0x4611cb(0x1ee)]=_0x250102[_0x4611cb(0x229)](_0x193153,_0x4611cb(0x1d0)),_0x3723db[_0x4611cb(0x239)]['opacity']='0.5';}(_0x44277b['contrast_scribbles']||_0x44277b[_0x4611cb(0x217)]&&_0x44277b[_0x4611cb(0x1e2)]&&_0x44277b['scribbleSoftnessFixed'])&&(_0xd669cd['style'][_0x4611cb(0x242)]='100%',_0xeb2930[_0x4611cb(0x239)]['width']=_0x4611cb(0x220),_0x16a8be[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272));_0x44277b['scribbleColorFixed']&&(_0x4517ef['style'][_0x4611cb(0x1fe)]=_0x4611cb(0x272));_0x44277b[_0x4611cb(0x21a)]&&(_0xd669cd['style']['display']=_0x4611cb(0x272));_0x44277b[_0x4611cb(0x1e2)]&&(_0x102790[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272));_0x44277b['scribbleSoftnessFixed']&&(_0x4ddd50[_0x4611cb(0x239)][_0x4611cb(0x1fe)]=_0x4611cb(0x272));const _0x39af6d=new ResizeObserver(()=>{const _0x109c1b=_0x4611cb;_0x44277b[_0x109c1b(0x209)](),_0x44277b[_0x109c1b(0x26e)]();});_0x39af6d[_0x4611cb(0x26f)](_0x39d95e),document[_0x4611cb(0x26d)](_0x4611cb(0x23d)+_0x44277b['uuid'])[_0x4611cb(0x1de)](_0x4611cb(0x260),function(_0x529663){const _0x49a8fe=_0x4611cb;_0x44277b[_0x49a8fe(0x206)](_0x529663[_0x49a8fe(0x1e9)][_0x49a8fe(0x20b)][0x0]);}),_0x2ba998['addEventListener'](_0x4611cb(0x23b),function(){const _0x1fedee=_0x4611cb;if(_0x44277b[_0x1fedee(0x1f5)])return;document[_0x1fedee(0x26d)]('imageInput_'+_0x44277b[_0x1fedee(0x1dc)])[_0x1fedee(0x23b)]();}),_0x3c3fac['addEventListener'](_0x4611cb(0x23b),function(){const _0x378ac5=_0x4611cb;_0x44277b[_0x378ac5(0x26b)]();}),_0x47f742['addEventListener'](_0x4611cb(0x23b),function(){const _0x2c9eba=_0x4611cb;_0x44277b['adjustInitialPositionAndScale'](),_0x44277b[_0x2c9eba(0x26e)]();}),_0x718fea[_0x4611cb(0x1de)](_0x4611cb(0x23b),function(){const _0x44bc8b=_0x4611cb;_0x44277b[_0x44bc8b(0x1e4)]();}),_0x31aae5[_0x4611cb(0x1de)](_0x4611cb(0x23b),function(){const _0x595d80=_0x4611cb;_0x44277b[_0x595d80(0x252)]();}),_0xdd4e04['addEventListener'](_0x4611cb(0x23b),function(){const 
_0x5e57d4=_0x4611cb;_0x44277b[_0x5e57d4(0x1c6)]();}),_0x12201f[_0x4611cb(0x1de)](_0x4611cb(0x204),function(){const _0xe753c6=_0x4611cb;_0x44277b[_0xe753c6(0x24c)]=this[_0xe753c6(0x1f6)],_0x67434b[_0xe753c6(0x239)][_0xe753c6(0x257)]=_0x44277b[_0xe753c6(0x24c)];}),_0xeb2930[_0x4611cb(0x1de)](_0x4611cb(0x204),function(){const _0x4fb7fc=_0x4611cb;_0x44277b[_0x4fb7fc(0x1e0)]=this[_0x4fb7fc(0x1f6)];const _0x52346e=_0x44277b[_0x4fb7fc(0x1e0)]*0x14;_0x67434b['style']['width']=_0x52346e+'px',_0x67434b['style'][_0x4fb7fc(0x203)]=_0x52346e+'px';}),_0x4aac00['addEventListener'](_0x4611cb(0x204),function(){const _0x335d66=_0x4611cb;_0x44277b[_0x335d66(0x230)]=this['value'];}),_0x4cd12a['addEventListener'](_0x4611cb(0x204),function(){const _0x29257f=_0x4611cb;_0x44277b[_0x29257f(0x21d)]=this['value'];}),_0x3723db[_0x4611cb(0x1de)](_0x4611cb(0x27d),function(_0x1e1333){const _0x56b21e=_0x4611cb;if(!_0x44277b[_0x56b21e(0x28d)]||_0x1e1333[_0x56b21e(0x27c)]!==0x0||_0x44277b[_0x56b21e(0x241)])return;const _0x591721=_0x3723db[_0x56b21e(0x244)]();_0x44277b[_0x56b21e(0x1fb)]=!![],_0x3723db[_0x56b21e(0x239)]['cursor']='crosshair',_0x67434b[_0x56b21e(0x239)][_0x56b21e(0x1fe)]=_0x56b21e(0x272),_0x44277b[_0x56b21e(0x213)]=[[(_0x1e1333[_0x56b21e(0x1e7)]-_0x591721[_0x56b21e(0x275)])/_0x44277b[_0x56b21e(0x1f9)],(_0x1e1333['clientY']-_0x591721[_0x56b21e(0x1ce)])/_0x44277b['imgScale']]],_0x44277b[_0x56b21e(0x281)]=_0x250102[_0x56b21e(0x1d2)](0x0,0x0,_0x3723db[_0x56b21e(0x242)],_0x3723db[_0x56b21e(0x203)]),_0x44277b[_0x56b21e(0x265)](_0x1e1333);}),_0x3723db[_0x4611cb(0x1de)]('mousemove',function(_0xd2f5e6){const _0x241d14=_0x4611cb;_0x44277b[_0x241d14(0x1fb)]&&_0x44277b['handleDraw'](_0xd2f5e6);_0x44277b[_0x241d14(0x28d)]&&!_0x44277b[_0x241d14(0x218)]&&(_0x3723db[_0x241d14(0x239)][_0x241d14(0x223)]=_0x241d14(0x25d));if(_0x44277b[_0x241d14(0x28d)]&&!_0x44277b[_0x241d14(0x1fb)]&&!_0x44277b[_0x241d14(0x218)]&&!_0x44277b['no_scribbles']){const _0x110a6d=_0x540dd1[_0x241d14(0x244)](),_0x454ab9=_0x44277b[_0x241d14(0x1e0)]*0xa;_0x67434b[_0x241d14(0x239)][_0x241d14(0x275)]=_0xd2f5e6[_0x241d14(0x1e7)]-_0x110a6d[_0x241d14(0x275)]-_0x454ab9+'px',_0x67434b[_0x241d14(0x239)][_0x241d14(0x1ce)]=_0xd2f5e6[_0x241d14(0x23e)]-_0x110a6d['top']-_0x454ab9+'px',_0x67434b[_0x241d14(0x239)][_0x241d14(0x1fe)]=_0x241d14(0x248);}}),_0x3723db[_0x4611cb(0x1de)](_0x4611cb(0x20d),function(){const _0x2eebc5=_0x4611cb;_0x44277b[_0x2eebc5(0x1fb)]=![],_0x3723db[_0x2eebc5(0x239)][_0x2eebc5(0x223)]='',_0x44277b[_0x2eebc5(0x22c)]();}),_0x3723db['addEventListener'](_0x4611cb(0x1e5),function(){const _0x4461c3=_0x4611cb;_0x44277b[_0x4461c3(0x1fb)]=![],_0x3723db[_0x4461c3(0x239)][_0x4461c3(0x223)]='',_0x67434b[_0x4461c3(0x239)]['display']=_0x4461c3(0x272);}),_0x3201f4[_0x4611cb(0x1de)](_0x4611cb(0x27d),function(_0x5439f1){const _0x3f0cb9=_0x4611cb;_0x5439f1[_0x3f0cb9(0x1f1)]();}),_0x540dd1[_0x4611cb(0x1de)](_0x4611cb(0x27d),function(_0xe3fcfd){const _0x149dfb=_0x4611cb,_0x1854f1=_0x540dd1[_0x149dfb(0x244)](),_0x4f9bc1=_0xe3fcfd[_0x149dfb(0x1e7)]-_0x1854f1[_0x149dfb(0x275)],_0x240702=_0xe3fcfd[_0x149dfb(0x23e)]-_0x1854f1['top'];if(_0xe3fcfd[_0x149dfb(0x27c)]===0x2&&_0x44277b[_0x149dfb(0x1e3)](_0x4f9bc1,_0x240702))_0x44277b[_0x149dfb(0x218)]=!![],_0x44277b['offsetX']=_0x4f9bc1-_0x44277b[_0x149dfb(0x250)],_0x44277b[_0x149dfb(0x201)]=_0x240702-_0x44277b[_0x149dfb(0x24e)],_0x6f013c[_0x149dfb(0x239)][_0x149dfb(0x223)]='grabbing',_0x3723db[_0x149dfb(0x239)][_0x149dfb(0x223)]=_0x149dfb(0x1da),_0x67434b[_0x149dfb(0x239)]['display']=_0x149dfb(0x272);else 
_0xe3fcfd[_0x149dfb(0x27c)]===0x0&&!_0x44277b['img']&&!_0x44277b[_0x149dfb(0x1f5)]&&document[_0x149dfb(0x26d)](_0x149dfb(0x23d)+_0x44277b[_0x149dfb(0x1dc)])[_0x149dfb(0x23b)]();}),_0x540dd1[_0x4611cb(0x1de)]('mousemove',function(_0x523813){const _0x5f0b2c=_0x4611cb;if(_0x44277b['dragging']){const _0x25e960=_0x540dd1[_0x5f0b2c(0x244)](),_0x13ef14=_0x523813[_0x5f0b2c(0x1e7)]-_0x25e960['left'],_0x47fe96=_0x523813[_0x5f0b2c(0x23e)]-_0x25e960[_0x5f0b2c(0x1ce)];_0x44277b['imgX']=_0x13ef14-_0x44277b[_0x5f0b2c(0x224)],_0x44277b[_0x5f0b2c(0x24e)]=_0x47fe96-_0x44277b[_0x5f0b2c(0x201)],_0x44277b[_0x5f0b2c(0x26e)](),_0x44277b['dragged_just_now']=!![];}}),_0x540dd1[_0x4611cb(0x1de)](_0x4611cb(0x20d),function(_0x5cb4f){const _0x5bf59f=_0x4611cb;_0x44277b[_0x5bf59f(0x218)]&&_0x44277b[_0x5bf59f(0x1f8)](_0x5cb4f,![]);}),_0x540dd1[_0x4611cb(0x1de)]('mouseleave',function(_0x19abb7){const _0x429eba=_0x4611cb;_0x44277b[_0x429eba(0x218)]&&_0x44277b[_0x429eba(0x1f8)](_0x19abb7,!![]);}),_0x540dd1[_0x4611cb(0x1de)]('wheel',function(_0x5c9bc8){const _0x116d3a=_0x4611cb;if(!_0x44277b[_0x116d3a(0x28d)])return;_0x5c9bc8['preventDefault']();const _0x21a087=_0x540dd1[_0x116d3a(0x244)](),_0x1a73a6=_0x5c9bc8['clientX']-_0x21a087[_0x116d3a(0x275)],_0x3bb7b0=_0x5c9bc8[_0x116d3a(0x23e)]-_0x21a087['top'],_0x2bf160=_0x44277b[_0x116d3a(0x1f9)],_0x503211=_0x5c9bc8[_0x116d3a(0x219)]*-0.001;_0x44277b[_0x116d3a(0x1f9)]+=_0x503211,_0x44277b['imgScale']=Math['max'](0.1,_0x44277b[_0x116d3a(0x1f9)]);const _0x3e5feb=_0x44277b[_0x116d3a(0x1f9)]/_0x2bf160;_0x44277b['imgX']=_0x1a73a6-(_0x1a73a6-_0x44277b[_0x116d3a(0x250)])*_0x3e5feb,_0x44277b[_0x116d3a(0x24e)]=_0x3bb7b0-(_0x3bb7b0-_0x44277b[_0x116d3a(0x24e)])*_0x3e5feb,_0x44277b[_0x116d3a(0x26e)]();}),_0x540dd1['addEventListener'](_0x4611cb(0x21f),function(_0x44c56a){const _0x2b68b3=_0x4611cb;_0x44277b['dragged_just_now']&&_0x44c56a[_0x2b68b3(0x236)](),_0x44277b[_0x2b68b3(0x261)]=![];}),_0x540dd1[_0x4611cb(0x1de)](_0x4611cb(0x20e),function(){const _0x311bcf=_0x4611cb;_0x3201f4['style'][_0x311bcf(0x20a)]='1',!_0x44277b[_0x311bcf(0x28d)]&&!_0x44277b[_0x311bcf(0x1f5)]&&(_0x540dd1[_0x311bcf(0x239)][_0x311bcf(0x223)]='pointer');}),_0x540dd1[_0x4611cb(0x1de)](_0x4611cb(0x1d3),function(){const _0x19ff1c=_0x4611cb;_0x3201f4['style'][_0x19ff1c(0x20a)]='0',_0x6f013c[_0x19ff1c(0x239)][_0x19ff1c(0x223)]='',_0x3723db[_0x19ff1c(0x239)]['cursor']='',_0x540dd1['style']['cursor']='',_0x67434b[_0x19ff1c(0x239)][_0x19ff1c(0x1fe)]=_0x19ff1c(0x272);}),_0x1d67ee[_0x4611cb(0x1de)](_0x4611cb(0x27d),function(_0x40f337){const _0x1a39c5=_0x4611cb;_0x44277b['resizing']=!![],_0x40f337[_0x1a39c5(0x236)](),_0x40f337[_0x1a39c5(0x1f1)]();}),document[_0x4611cb(0x1de)](_0x4611cb(0x1ea),function(_0x14ad60){const _0x301046=_0x4611cb;if(_0x44277b['resizing']){const _0x350a32=_0x39d95e[_0x301046(0x244)](),_0x2857b4=_0x14ad60['clientY']-_0x350a32['top'];_0x39d95e[_0x301046(0x239)][_0x301046(0x203)]=_0x2857b4+'px',_0x14ad60['preventDefault'](),_0x14ad60[_0x301046(0x1f1)]();}}),document[_0x4611cb(0x1de)](_0x4611cb(0x20d),function(){const _0x19b8d4=_0x4611cb;_0x44277b[_0x19b8d4(0x1ec)]=![];}),document[_0x4611cb(0x1de)](_0x4611cb(0x1e5),function(){const _0x498f32=_0x4611cb;_0x44277b[_0x498f32(0x1ec)]=![];}),['dragenter',_0x4611cb(0x266),_0x4611cb(0x1c8),_0x4611cb(0x1e6)][_0x4611cb(0x1fd)](_0x1ff25d=>{const _0xebdc28=_0x4611cb;_0x540dd1[_0xebdc28(0x1de)](_0x1ff25d,_0x1b55c8,![]);});function _0x1b55c8(_0x2f6e52){const 
_0x1e361a=_0x4611cb;_0x2f6e52[_0x1e361a(0x236)](),_0x2f6e52[_0x1e361a(0x1f1)]();}_0x540dd1[_0x4611cb(0x1de)](_0x4611cb(0x1cb),()=>{const _0x3e07a7=_0x4611cb;_0x6f013c[_0x3e07a7(0x239)]['cursor']=_0x3e07a7(0x253),_0x3723db[_0x3e07a7(0x239)][_0x3e07a7(0x223)]=_0x3e07a7(0x253);}),_0x540dd1[_0x4611cb(0x1de)]('dragleave',()=>{const _0xe21f36=_0x4611cb;_0x6f013c['style']['cursor']='',_0x3723db[_0xe21f36(0x239)][_0xe21f36(0x223)]='';}),_0x540dd1[_0x4611cb(0x1de)]('drop',function(_0xfd1575){const _0x5f2856=_0x4611cb;_0x6f013c[_0x5f2856(0x239)][_0x5f2856(0x223)]='',_0x3723db[_0x5f2856(0x239)][_0x5f2856(0x223)]='';const _0x3f69e0=_0xfd1575['dataTransfer'],_0x58abcc=_0x3f69e0[_0x5f2856(0x20b)];_0x58abcc[_0x5f2856(0x20f)]>0x0&&_0x44277b['handleFileUpload'](_0x58abcc[0x0]);}),_0x540dd1[_0x4611cb(0x1de)]('mouseenter',()=>{const _0x448e84=_0x4611cb;_0x44277b[_0x448e84(0x254)]=!![];}),_0x540dd1[_0x4611cb(0x1de)](_0x4611cb(0x1e5),()=>{const _0x1881e2=_0x4611cb;_0x44277b[_0x1881e2(0x254)]=![];}),document[_0x4611cb(0x1de)](_0x4611cb(0x251),function(_0x2c1418){const _0x210e83=_0x4611cb;_0x44277b[_0x210e83(0x254)]&&_0x44277b[_0x210e83(0x1d4)](_0x2c1418);}),document['addEventListener'](_0x4611cb(0x1fc),_0x563be4=>{const _0x2a82c6=_0x4611cb;if(!_0x44277b[_0x2a82c6(0x254)])return;_0x563be4[_0x2a82c6(0x25a)]&&_0x563be4[_0x2a82c6(0x1dd)]==='z'&&(_0x563be4[_0x2a82c6(0x236)](),this[_0x2a82c6(0x252)]()),_0x563be4[_0x2a82c6(0x25a)]&&_0x563be4[_0x2a82c6(0x1dd)]==='y'&&(_0x563be4[_0x2a82c6(0x236)](),this[_0x2a82c6(0x1c6)]());}),_0x2fdf34[_0x4611cb(0x1de)]('click',function(){const _0x1670d7=_0x4611cb;_0x44277b[_0x1670d7(0x22a)]();}),_0x323f6d[_0x4611cb(0x1de)](_0x4611cb(0x23b),function(){const _0x1f9b8a=_0x4611cb;_0x44277b[_0x1f9b8a(0x1c7)]();}),_0x44277b[_0x4611cb(0x27a)](),_0x44277b[_0x4611cb(0x259)][_0x4611cb(0x274)](function(_0x3ad581){const _0x4f3d4b=_0x4611cb;_0x44277b[_0x4f3d4b(0x211)](_0x3ad581);}),_0x44277b['foreground_gradio_bind'][_0x4611cb(0x274)](function(_0x55fa67){_0x44277b['uploadBase64DrawingCanvas'](_0x55fa67);});}['handleDraw'](_0x1e24bd){const _0x51bcda=_0x374a8c,_0x4766b7=this[_0x51bcda(0x27f)],_0x3ca2ff=_0x4766b7[_0x51bcda(0x24d)]('2d'),_0x5be4b2=_0x4766b7[_0x51bcda(0x244)](),_0x2df6b2=(_0x1e24bd[_0x51bcda(0x1e7)]-_0x5be4b2[_0x51bcda(0x275)])/this[_0x51bcda(0x1f9)],_0x23d726=(_0x1e24bd[_0x51bcda(0x23e)]-_0x5be4b2[_0x51bcda(0x1ce)])/this['imgScale'];this[_0x51bcda(0x213)][_0x51bcda(0x284)]([_0x2df6b2,_0x23d726]),_0x3ca2ff[_0x51bcda(0x289)](this[_0x51bcda(0x281)],0x0,0x0),_0x3ca2ff[_0x51bcda(0x1ed)](),_0x3ca2ff[_0x51bcda(0x243)](this[_0x51bcda(0x213)][0x0][0x0],this[_0x51bcda(0x213)][0x0][0x1]);for(let _0x5d4b63=0x1;_0x5d4b630x0)){_0x3ca2ff[_0x51bcda(0x247)]=_0x51bcda(0x23f),_0x3ca2ff[_0x51bcda(0x26a)]=0x1,_0x3ca2ff[_0x51bcda(0x288)]();return;}_0x3ca2ff[_0x51bcda(0x247)]=_0x51bcda(0x1d8);if(!(this[_0x51bcda(0x21d)]>0x0)){_0x3ca2ff[_0x51bcda(0x26a)]=this[_0x51bcda(0x230)]/0x64,_0x3ca2ff[_0x51bcda(0x288)]();return;}const _0x12be7b=_0x3ca2ff[_0x51bcda(0x24b)]*(0x1-this[_0x51bcda(0x21d)]/0x96),_0x8dfe71=_0x3ca2ff['lineWidth']*(0x1+this['scribbleSoftness']/0x96),_0x58b349=Math[_0x51bcda(0x28b)](0x5+this['scribbleSoftness']/0x5),_0x36e61a=(_0x8dfe71-_0x12be7b)/(_0x58b349-0x1);_0x3ca2ff[_0x51bcda(0x26a)]=0x1-Math[_0x51bcda(0x25c)](0x1-Math['min'](this[_0x51bcda(0x230)]/0x64,0.95),0x1/_0x58b349);for(let _0x4a18d7=0x0;_0x4a18d7<_0x58b349;_0x4a18d7++){_0x3ca2ff[_0x51bcda(0x24b)]=_0x12be7b+_0x36e61a*_0x4a18d7,_0x3ca2ff[_0x51bcda(0x288)]();}}[_0x374a8c(0x206)](_0x5531ae){const 
_0x1048ff=_0x374a8c;if(_0x5531ae&&!this[_0x1048ff(0x1f5)]){const _0x4a10d8=new FileReader();_0x4a10d8[_0x1048ff(0x216)]=_0x5c0fd2=>{const _0x16ca45=_0x1048ff;this[_0x16ca45(0x211)](_0x5c0fd2[_0x16ca45(0x1e9)][_0x16ca45(0x25f)]);},_0x4a10d8['readAsDataURL'](_0x5531ae);}}[_0x374a8c(0x1d4)](_0x85c097){const _0x49df7d=_0x374a8c,_0x5deb03=_0x85c097[_0x49df7d(0x249)][_0x49df7d(0x1f2)];for(let _0x4bbeb0=0x0;_0x4bbeb0<_0x5deb03[_0x49df7d(0x20f)];_0x4bbeb0++){const _0x5ab24f=_0x5deb03[_0x4bbeb0];if(_0x5ab24f['type']['indexOf'](_0x49df7d(0x245))!==-0x1){const _0x4ee5c5=_0x5ab24f[_0x49df7d(0x235)]();this[_0x49df7d(0x206)](_0x4ee5c5);break;}}}[_0x374a8c(0x211)](_0xd82930){const _0x28362b=_0x374a8c;if(typeof this[_0x28362b(0x1d7)]!==_0x28362b(0x1ff)){if(!this[_0x28362b(0x1d7)][_0x28362b(0x25e)][_0x28362b(0x278)](0x1+0x4+0x1-0x2+''))return;}else return;const _0x4eec46=new Image();_0x4eec46[_0x28362b(0x216)]=()=>{const _0xa13c84=_0x28362b;this[_0xa13c84(0x28d)]=_0xd82930,this['orgWidth']=_0x4eec46[_0xa13c84(0x242)],this[_0xa13c84(0x24a)]=_0x4eec46[_0xa13c84(0x203)];const _0x35593e=document['getElementById'](_0xa13c84(0x237)+this['uuid']);(_0x35593e[_0xa13c84(0x242)]!==_0x4eec46['width']||_0x35593e[_0xa13c84(0x203)]!==_0x4eec46[_0xa13c84(0x203)])&&(_0x35593e[_0xa13c84(0x242)]=_0x4eec46['width'],_0x35593e[_0xa13c84(0x203)]=_0x4eec46[_0xa13c84(0x203)]),this['adjustInitialPositionAndScale'](),this[_0xa13c84(0x26e)](),this[_0xa13c84(0x279)](),this['saveState'](),document[_0xa13c84(0x26d)](_0xa13c84(0x23d)+this['uuid'])[_0xa13c84(0x1f6)]=null,document['getElementById'](_0xa13c84(0x22b)+this[_0xa13c84(0x1dc)])[_0xa13c84(0x239)]['display']=_0xa13c84(0x272);};if(_0xd82930)_0x4eec46['src']=_0xd82930;else{this[_0x28362b(0x28d)]=null;const _0x4da790=document[_0x28362b(0x26d)](_0x28362b(0x237)+this[_0x28362b(0x1dc)]);_0x4da790[_0x28362b(0x242)]=0x1,_0x4da790['height']=0x1,this[_0x28362b(0x209)](),this[_0x28362b(0x26e)](),this[_0x28362b(0x279)](),this[_0x28362b(0x22c)]();}}[_0x374a8c(0x22e)](_0xea0055){const _0x5af1ce=_0x374a8c,_0x158e8a=new Image();_0x158e8a[_0x5af1ce(0x216)]=()=>{const _0x412900=_0x5af1ce,_0x5e7d3e=document[_0x412900(0x26d)](_0x412900(0x237)+this[_0x412900(0x1dc)]),_0x49fa8c=_0x5e7d3e['getContext']('2d');_0x49fa8c[_0x412900(0x212)](0x0,0x0,_0x5e7d3e['width'],_0x5e7d3e['height']),_0x49fa8c[_0x412900(0x26e)](_0x158e8a,0x0,0x0),this[_0x412900(0x22c)]();};if(_0xea0055)_0x158e8a[_0x5af1ce(0x256)]=_0xea0055;else{const _0x127e54=document[_0x5af1ce(0x26d)](_0x5af1ce(0x237)+this[_0x5af1ce(0x1dc)]),_0x4d21d0=_0x127e54[_0x5af1ce(0x24d)]('2d');_0x4d21d0[_0x5af1ce(0x212)](0x0,0x0,_0x127e54[_0x5af1ce(0x242)],_0x127e54[_0x5af1ce(0x203)]),this[_0x5af1ce(0x22c)]();}}[_0x374a8c(0x1e3)](_0x3d1f50,_0xf5ae34){const _0x462f7c=_0x374a8c,_0x58db8e=this[_0x462f7c(0x1d9)]*this[_0x462f7c(0x1f9)],_0x4045e8=this[_0x462f7c(0x24a)]*this['imgScale'];return _0x3d1f50>this[_0x462f7c(0x250)]&&_0x3d1f50this['imgY']&&_0xf5ae340x0&&(this[_0xa2a5f0(0x246)]--,this['restoreState'](),this['updateUndoRedoButtons']());}['redo'](){const _0x209b5a=_0x374a8c;this[_0x209b5a(0x246)]=this['history']['length']-0x1,_0x376ab2[_0x8ba24e(0x239)][_0x8ba24e(0x20a)]=_0x376ab2[_0x8ba24e(0x221)]?_0x8ba24e(0x207):'1',_0x4170b8['style'][_0x8ba24e(0x20a)]=_0x4170b8['disabled']?_0x8ba24e(0x207):'1';}[_0x374a8c(0x279)](){const _0xe49cc0=_0x374a8c;if(!this[_0xe49cc0(0x28d)]){this['background_gradio_bind']['set_value']('');return;}const 
_0x1251b9=document['getElementById']('image_'+this[_0xe49cc0(0x1dc)]),_0x8a2e42=this['temp_canvas'],_0x534739=_0x8a2e42[_0xe49cc0(0x24d)]('2d');_0x8a2e42['width']=this[_0xe49cc0(0x1d9)],_0x8a2e42[_0xe49cc0(0x203)]=this[_0xe49cc0(0x24a)],_0x534739[_0xe49cc0(0x26e)](_0x1251b9,0x0,0x0,this[_0xe49cc0(0x1d9)],this[_0xe49cc0(0x24a)]);const _0x41ec46=_0x8a2e42[_0xe49cc0(0x1e1)](_0xe49cc0(0x1d6));this[_0xe49cc0(0x259)][_0xe49cc0(0x231)](_0x41ec46);}[_0x374a8c(0x225)](){const _0x374804=_0x374a8c;if(!this[_0x374804(0x28d)]){this[_0x374804(0x200)][_0x374804(0x231)]('');return;}const _0x388adf=document[_0x374804(0x26d)]('drawingCanvas_'+this[_0x374804(0x1dc)]),_0x368c03=_0x388adf[_0x374804(0x1e1)](_0x374804(0x1d6));this[_0x374804(0x200)]['set_value'](_0x368c03);}[_0x374a8c(0x22a)](){const _0xa73976=_0x374a8c;if(this[_0xa73976(0x1cc)])return;const _0x302b84=document[_0xa73976(0x26d)](_0xa73976(0x282)+this['uuid']),_0xf60ce8=document[_0xa73976(0x26d)](_0xa73976(0x286)+this['uuid']),_0x5b82e3=document[_0xa73976(0x26d)](_0xa73976(0x228)+this[_0xa73976(0x1dc)]),_0x19f97=document[_0xa73976(0x26d)]('minButton_'+this[_0xa73976(0x1dc)]);this[_0xa73976(0x28a)]={'width':_0x302b84[_0xa73976(0x239)][_0xa73976(0x242)],'height':_0x302b84[_0xa73976(0x239)]['height'],'top':_0x302b84['style'][_0xa73976(0x1ce)],'left':_0x302b84['style']['left'],'position':_0x302b84[_0xa73976(0x239)]['position'],'zIndex':_0x302b84[_0xa73976(0x239)][_0xa73976(0x226)]},_0x302b84[_0xa73976(0x239)][_0xa73976(0x242)]=_0xa73976(0x232),_0x302b84[_0xa73976(0x239)]['height']=_0xa73976(0x21b),_0x302b84[_0xa73976(0x239)]['top']='0',_0x302b84['style'][_0xa73976(0x275)]='0',_0x302b84[_0xa73976(0x239)][_0xa73976(0x1f0)]=_0xa73976(0x271),_0x302b84[_0xa73976(0x239)][_0xa73976(0x226)]=_0xa73976(0x238),_0x5b82e3[_0xa73976(0x239)][_0xa73976(0x1fe)]=_0xa73976(0x272),_0x19f97[_0xa73976(0x239)][_0xa73976(0x1fe)]=_0xa73976(0x264),this[_0xa73976(0x1cc)]=!![];}[_0x374a8c(0x1c7)](){const _0x1bc47a=_0x374a8c;if(!this['maximized'])return;const _0x4ccdb3=document['getElementById'](_0x1bc47a(0x282)+this['uuid']),_0x116cf6=document[_0x1bc47a(0x26d)](_0x1bc47a(0x228)+this[_0x1bc47a(0x1dc)]),_0x1ba7c0=document[_0x1bc47a(0x26d)]('minButton_'+this[_0x1bc47a(0x1dc)]);_0x4ccdb3[_0x1bc47a(0x239)]['width']=this[_0x1bc47a(0x28a)][_0x1bc47a(0x242)],_0x4ccdb3[_0x1bc47a(0x239)][_0x1bc47a(0x203)]=this['originalState'][_0x1bc47a(0x203)],_0x4ccdb3[_0x1bc47a(0x239)][_0x1bc47a(0x1ce)]=this[_0x1bc47a(0x28a)][_0x1bc47a(0x1ce)],_0x4ccdb3[_0x1bc47a(0x239)][_0x1bc47a(0x275)]=this[_0x1bc47a(0x28a)][_0x1bc47a(0x275)],_0x4ccdb3[_0x1bc47a(0x239)][_0x1bc47a(0x1f0)]=this[_0x1bc47a(0x28a)][_0x1bc47a(0x1f0)],_0x4ccdb3[_0x1bc47a(0x239)][_0x1bc47a(0x226)]=this[_0x1bc47a(0x28a)][_0x1bc47a(0x226)],_0x116cf6['style'][_0x1bc47a(0x1fe)]=_0x1bc47a(0x264),_0x1ba7c0[_0x1bc47a(0x239)][_0x1bc47a(0x1fe)]=_0x1bc47a(0x272),this[_0x1bc47a(0x1cc)]=![];}[_0x374a8c(0x1f8)](_0x5d8941,_0x1ca1cf){const _0x1db4da=_0x374a8c,_0x1065a9=document[_0x1db4da(0x26d)](_0x1db4da(0x227)+this[_0x1db4da(0x1dc)]),_0x44fd56=document[_0x1db4da(0x26d)](_0x1db4da(0x237)+this['uuid']);this[_0x1db4da(0x218)]=![],_0x1065a9[_0x1db4da(0x239)][_0x1db4da(0x223)]=_0x1db4da(0x277),_0x44fd56[_0x1db4da(0x239)]['cursor']=_0x1db4da(0x277);}}function _0x3ec7(){const 
_0x2f644a=['13644547gnftKV','uploadButton_','inline-block','handleDraw','dragover','redoButton_','scribbleColor_','scribbleSoftnessFixed','globalAlpha','resetImage','temp_canvas','getElementById','drawImage','observe','12132616ymoAjN','fixed','none','50194130SOKNXq','listen','left','scribbleWidthBlock_','grab','startsWith','on_img_upload','updateUndoRedoButtons','defineProperty','button','mousedown','26322OKSXUe','drawingCanvas','slice','temp_draw_bg','container_','scribbleColorBlock_','push','clientWidth','toolbar_','\x20textarea','stroke','putImageData','originalState','round','1OGJEZj','img','redo','minimize','dragleave','scribbleSoftness_','imageContainer_','dragenter','maximized','resizeLine_','top','contrast_scribbles','repeat','logical_image_background','getImageData','mouseout','handlePaste','logical_image_foreground','image/png','gradio_config','source-over','orgWidth','grabbing','fillStyle','uuid','key','addEventListener','alphaLabel_','scribbleWidth','toDataURL','scribbleAlphaFixed','isInsideImage','removeImage','mouseleave','drop','clientX','9mbfKci','target','mousemove','start','resizing','beginPath','contrast_pattern','querySelector','position','stopPropagation','items','#000000','initial_height','no_upload','value','undoButton_','handleDragEnd','imgScale','strokeStyle','drawing','keydown','forEach','display','undefined','foreground_gradio_bind','offsetY','removeButton_','height','input','9368075wZYBhj','handleFileUpload','0.5','previousValue','adjustInitialPositionAndScale','opacity','files','dispatchEvent','mouseup','mouseover','length','fillRect','uploadBase64','clearRect','temp_draw_points','min','scribbleIndicator_','onload','scribbleColorFixed','dragging','deltaY','scribbleWidthFixed','100vh','#ffffff','scribbleSoftness','centerButton_','contextmenu','100%','disabled','scribbleSoftnessBlock_','cursor','offsetX','on_drawing_canvas_upload','zIndex','image_','maxButton_','createPattern','maximize','uploadHint_','saveState','sync_lock','uploadBase64DrawingCanvas','clientHeight','scribbleAlpha','set_value','100vw','lineJoin','widthLabel_','getAsFile','preventDefault','drawingCanvas_','1000','style','minButton_','click','history','imageInput_','clientY','destination-out','3123806gxlYDD','no_scribbles','width','moveTo','getBoundingClientRect','image','historyIndex','globalCompositeOperation','block','clipboardData','orgHeight','lineWidth','scribbleColor','getContext','imgY','restoreState','imgX','paste','undo','copy','mouseInsideContainer','412gshJgS','src','borderColor','scribbleAlphaBlock_','background_gradio_bind','ctrlKey','10829976sdcoBo','pow','crosshair','version','result','change','dragged_just_now'];_0x3ec7=function(){return _0x2f644a;};return _0x3ec7();}function _0xe5ae(_0xfed99,_0xc31b63){const _0x3ec7ec=_0x3ec7();return _0xe5ae=function(_0xe5ae18,_0x559b8){_0xe5ae18=_0xe5ae18-0x1c6;let _0x13de64=_0x3ec7ec[_0xe5ae18];return _0x13de64;},_0xe5ae(_0xfed99,_0xc31b63);}const True=!![],False=![]; \ No newline at end of file diff --git a/modules_forge/forge_canvas/canvas.py b/modules_forge/forge_canvas/canvas.py new file mode 100644 index 000000000..ce29759c8 --- /dev/null +++ b/modules_forge/forge_canvas/canvas.py @@ -0,0 +1,120 @@ +# Forge Canvas +# AGPL V3 +# by lllyasviel +# Commercial Use is not allowed. (Contact us for commercial use.) 
+ + +import os +import uuid +import base64 +import gradio as gr +import numpy as np + +from PIL import Image +from io import BytesIO +from gradio.context import Context +from functools import wraps +from modules.ui_components import FormComponent + + +canvas_js_root_path = os.path.dirname(__file__) + + +def web_js(file_name): + full_path = os.path.join(canvas_js_root_path, file_name) + return f'\n' + + +def web_css(file_name): + full_path = os.path.join(canvas_js_root_path, file_name) + return f'\n' + + +DEBUG_MODE = False + +canvas_html = open(os.path.join(canvas_js_root_path, 'canvas.html'), encoding='utf-8').read() +canvas_head = '' +canvas_head += web_css('canvas.css') +canvas_head += web_js('canvas.min.js') + + +def image_to_base64(image_array, numpy=True): + image = Image.fromarray(image_array) if numpy else image_array + image = image.convert("RGBA") + buffered = BytesIO() + image.save(buffered, format="PNG") + image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + return f"data:image/png;base64,{image_base64}" + + +def base64_to_image(base64_str, numpy=True): + if base64_str.startswith("data:image/png;base64,"): + base64_str = base64_str.replace("data:image/png;base64,", "") + image_data = base64.b64decode(base64_str) + image = Image.open(BytesIO(image_data)) + image = image.convert("RGBA") + image_array = np.array(image) if numpy else image + return image_array + + +class LogicalImage(gr.Textbox, FormComponent): + @wraps(gr.Textbox.__init__) + def __init__(self, *args, numpy=True, **kwargs): + self.numpy = numpy + + if 'value' in kwargs: + initial_value = kwargs['value'] + if initial_value is not None: + kwargs['value'] = self.image_to_base64(initial_value) + else: + del kwargs['value'] + + super().__init__(*args, **kwargs) + + def preprocess(self, payload): + if not isinstance(payload, str): + return None + + if not payload.startswith("data:image/png;base64,"): + return None + + return base64_to_image(payload, numpy=self.numpy) + + def postprocess(self, value): + if value is None: + return None + + return image_to_base64(value, numpy=self.numpy) + + def get_block_name(self): + return "textbox" + + +class ForgeCanvas: + def __init__( + self, + no_upload=False, + no_scribbles=False, + contrast_scribbles=False, + height=512, + scribble_color='#000000', + scribble_color_fixed=False, + scribble_width=4, + scribble_width_fixed=False, + scribble_alpha=100, + scribble_alpha_fixed=False, + scribble_softness=0, + scribble_softness_fixed=False, + visible=True, + numpy=False, + initial_image=None, + elem_id=None, + elem_classes=None + ): + self.uuid = 'uuid_' + uuid.uuid4().hex + self.block = gr.HTML(canvas_html.replace('forge_mixin', self.uuid), visible=visible, elem_id=elem_id, elem_classes=elem_classes) + self.foreground = LogicalImage(visible=DEBUG_MODE, label='foreground', numpy=numpy, elem_id=self.uuid, elem_classes=['logical_image_foreground']) + self.background = LogicalImage(visible=DEBUG_MODE, label='background', numpy=numpy, value=initial_image, elem_id=self.uuid, elem_classes=['logical_image_background']) + Context.root_block.load(None, js=f'async ()=>{{new ForgeCanvas("{self.uuid}", {no_upload}, {no_scribbles}, {contrast_scribbles}, {height}, ' + f"'{scribble_color}', {scribble_color_fixed}, {scribble_width}, {scribble_width_fixed}, " + f'{scribble_alpha}, {scribble_alpha_fixed}, {scribble_softness}, {scribble_softness_fixed});}}') diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index d3ea94a89..ba206db2b 100644 --- 
a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -223,7 +223,10 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): if getattr(sd_model, 'parameterization', None) == 'v': sd_model.forge_objects.unet.model.model_sampling = model_sampling(sd_model.forge_objects.unet.model.model_config, ModelType.V_PREDICTION) + sd_model.is_sd3 = False + sd_model.latent_channels = 4 sd_model.is_sdxl = conditioner is not None + sd_model.is_sdxl_inpaint = sd_model.is_sdxl and forge_objects.unet.model.diffusion_model.in_channels == 9 sd_model.is_sd2 = not sd_model.is_sdxl and hasattr(sd_model.cond_stage_model, 'model') sd_model.is_sd1 = not sd_model.is_sdxl and not sd_model.is_sd2 sd_model.is_ssd = sd_model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in sd_model.state_dict().keys() diff --git a/modules_forge/forge_version.py b/modules_forge/forge_version.py index c630142e6..eda0bdbe1 100644 --- a/modules_forge/forge_version.py +++ b/modules_forge/forge_version.py @@ -1 +1 @@ -version = '0.0.17v1.8.0rc' +version = '1.0.0v1.10.0rc' diff --git a/pyproject.toml b/pyproject.toml index ce6594171..214aa794d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,8 @@ target-version = "py39" +[tool.ruff.lint] + extend-select = [ "B", "C", @@ -29,10 +31,10 @@ ignore = [ "W605", # invalid escape sequence, messes with some docstrings ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "webui.py" = ["E402"] # Module level import not at top of file -[tool.ruff.flake8-bugbear] +[tool.ruff.lint.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] diff --git a/requirements.txt b/requirements.txt index 731a1be7d..98d9430af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,10 +4,11 @@ accelerate blendmodes clean-fid +diskcache einops facexlib fastapi>=0.90.1 -gradio==3.41.2 +gradio inflection jsonmerge kornia @@ -17,6 +18,7 @@ omegaconf open-clip-torch piexif +protobuf==3.20.0 psutil pytorch_lightning requests @@ -29,3 +31,4 @@ torch torchdiffeq torchsde transformers==4.30.2 +pillow-avif-plugin==1.4.3 \ No newline at end of file diff --git a/requirements_versions.txt b/requirements_versions.txt index 432e07b31..6f6530ae1 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,12 +1,14 @@ +setuptools==69.5.1 # temp fix for compatibility with some old packages GitPython==3.1.32 Pillow==9.5.0 accelerate==0.21.0 blendmodes==2022 clean-fid==0.1.35 +diskcache==5.6.3 einops==0.4.1 facexlib==0.3.0 -fastapi==0.94.0 -gradio==3.41.2 +fastapi==0.104.1 +gradio==4.39.0 httpcore==0.15 inflection==0.5.1 jsonmerge==1.8.0 @@ -16,18 +18,21 @@ numpy==1.26.2 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 +protobuf==3.20.0 psutil==5.9.5 pytorch_lightning==1.9.4 resize-right==0.0.2 safetensors==0.4.2 scikit-image==0.21.0 -spandrel==0.1.6 +spandrel==0.3.4 +spandrel-extra-arches==0.1.1 tomesd==0.1.3 torch torchdiffeq==0.2.3 torchsde==0.2.6 transformers==4.30.2 httpx==0.24.1 +pillow-avif-plugin==1.4.3 basicsr==1.4.2 -diffusers==0.25.0 -pydantic==1.10.15 +diffusers==0.28.0 +gradio_rangeslider==0.0.6 diff --git a/script.js b/script.js index f069b1ef0..de1a9000d 100644 --- a/script.js +++ b/script.js @@ -29,6 +29,7 @@ var uiAfterUpdateCallbacks = []; var uiLoadedCallbacks = []; var uiTabChangeCallbacks = []; var optionsChangedCallbacks = []; +var optionsAvailableCallbacks = []; var 
uiAfterUpdateTimeout = null;
 var uiCurrentTab = null;
 
@@ -77,6 +78,20 @@ function onOptionsChanged(callback) {
     optionsChangedCallbacks.push(callback);
 }
 
+/**
+ * Register callback to be called when the options (in opts global variable) are available.
+ * The callback receives no arguments.
+ * If you register the callback after the options are available, it's just immediately called.
+ */
+function onOptionsAvailable(callback) {
+    if (Object.keys(opts).length != 0) {
+        callback();
+        return;
+    }
+
+    optionsAvailableCallbacks.push(callback);
+}
+
 function executeCallbacks(queue, arg) {
     for (const callback of queue) {
         try {
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index c98ab4809..5df9dff9c 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -102,7 +102,7 @@ def _get_masked_window_rgb(np_mask_grey, hardness=1.):
     shaped_noise_fft = _fft2(noise_rgb)
     shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase  # perform the actual shaping
 
-    brightness_variation = 0.  # color_variation # todo: temporarily tieing brightness variation to color variation for now
+    brightness_variation = 0.  # color_variation # todo: temporarily tying brightness variation to color variation for now
     contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
 
     # scikit-image is used for histogram matching, very convenient!
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
index e1e156ddc..53a0cc44c 100644
--- a/scripts/postprocessing_codeformer.py
+++ b/scripts/postprocessing_codeformer.py
@@ -25,7 +25,7 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codefor
         if codeformer_visibility == 0 or not enable:
             return
 
-        restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
+        restored_img = codeformer_model.codeformer.restore(np.array(pp.image.convert("RGB"), dtype=np.uint8), w=codeformer_weight)
         res = Image.fromarray(restored_img)
 
         if codeformer_visibility < 1.0:
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
index 6e7566055..57e362399 100644
--- a/scripts/postprocessing_gfpgan.py
+++ b/scripts/postprocessing_gfpgan.py
@@ -22,7 +22,7 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_
         if gfpgan_visibility == 0 or not enable:
             return
 
-        restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
+        restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image.convert("RGB"), dtype=np.uint8))
         res = Image.fromarray(restored_img)
 
         if gfpgan_visibility < 1.0:
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index e269682d0..2409fd207 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -1,15 +1,28 @@
+import re
+
 from PIL import Image
 import numpy as np
 
 from modules import scripts_postprocessing, shared
 import gradio as gr
 
-from modules.ui_components import FormRow, ToolButton
+from modules.ui_components import FormRow, ToolButton, InputAccordion
 from modules.ui import switch_values_symbol
 
 upscale_cache = {}
 
 
+def limit_size_by_one_dimention(w, h, limit):
+    if h > w and h > limit:
+        w = limit * w // h
+        h = limit
+    elif w > limit:
+        h = limit * h // w
+        w = limit
+
+    return int(w), int(h)
+
+
 class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
     name = "Upscale"
     order = 1000
@@ -17,11 +30,22 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
     def ui(self):
         selected_tab = gr.Number(value=0, visible=False)
 
-        with gr.Column():
+        with InputAccordion(True, label="Upscale", elem_id="extras_upscale") as upscale_enabled:
+            with FormRow():
+                extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+            with FormRow():
+                extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+                extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+
             with FormRow():
                 with gr.Tabs(elem_id="extras_resize_mode"):
                     with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
-                        upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+                        with gr.Row():
+                            with gr.Column(scale=4):
+                                upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+                            with gr.Column(scale=1, min_width=160):
+                                max_side_length = gr.Number(label="Max side length", value=0, elem_id="extras_upscale_max_side_length", tooltip="If any of two sides of the image ends up larger than specified, will downscale it to fit. 0 = no limit.", min_width=160, step=8, minimum=0)
 
                     with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
                         with FormRow():
@@ -32,20 +56,27 @@ def ui(self):
                                 upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
                                 upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
 
-            with FormRow():
-                extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+        def on_selected_upscale_method(upscale_method):
+            if not shared.opts.set_scale_by_when_changing_upscaler:
+                return gr.update()
 
-            with FormRow():
-                extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
-                extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+            match = re.search(r'(\d)[xX]|[xX](\d)', upscale_method)
+            if not match:
+                return gr.update()
+
+            return gr.update(value=int(match.group(1) or match.group(2)))
 
         upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
         tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
         tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
+        extras_upscaler_1.change(on_selected_upscale_method, inputs=[extras_upscaler_1], outputs=[upscaling_resize], show_progress="hidden")
+
 
         return {
+            "upscale_enabled": upscale_enabled,
             "upscale_mode": selected_tab,
             "upscale_by": upscaling_resize,
+            "max_side_length": max_side_length,
             "upscale_to_width": upscaling_resize_w,
             "upscale_to_height": upscaling_resize_h,
             "upscale_crop": upscaling_crop,
@@ -54,12 +85,18 @@ def ui(self):
             "upscaler_2_visibility": extras_upscaler_2_visibility,
         }
 
-    def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
+    def upscale(self, image, info, upscaler, upscale_mode, upscale_by, max_side_length, upscale_to_width, upscale_to_height, upscale_crop):
         if upscale_mode == 1:
             upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
             info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
         else:
             info["Postprocess upscale by"] = upscale_by
 
+        if max_side_length != 0 and max(*image.size)*upscale_by > max_side_length:
+            upscale_mode = 1
+            upscale_crop = False
+            upscale_to_width, upscale_to_height = limit_size_by_one_dimention(image.width*upscale_by, image.height*upscale_by, max_side_length)
+            upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
+            info["Max side length"] = max_side_length
+
         cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
         cached_image = upscale_cache.pop(cache_key, None)
@@ -81,7 +118,7 @@ def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_w
         return image
 
-    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, max_side_length=0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
         if upscale_mode == 1:
             pp.shared.target_width = upscale_to_width
             pp.shared.target_height = upscale_to_height
@@ -89,7 +126,13 @@ def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upsca
             pp.shared.target_width = int(pp.image.width * upscale_by)
             pp.shared.target_height = int(pp.image.height * upscale_by)
 
-    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+        pp.shared.target_width, pp.shared.target_height = limit_size_by_one_dimention(pp.shared.target_width, pp.shared.target_height, max_side_length)
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, max_side_length=0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+        if not upscale_enabled:
+            return
+
+        upscaler_1_name = upscaler_1_name
         if upscaler_1_name == "None":
             upscaler_1_name = None
 
@@ -99,17 +142,20 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1,
         if not upscaler1:
             return
 
+        upscaler_2_name = upscaler_2_name
         if upscaler_2_name == "None":
             upscaler_2_name = None
 
         upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
         assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
 
-        upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+        upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, max_side_length, upscale_to_width, upscale_to_height, upscale_crop)
         pp.info["Postprocess upscaler"] = upscaler1.name
 
         if upscaler2 and upscaler_2_visibility > 0:
-            second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+            second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, max_side_length, upscale_to_width, upscale_to_height, upscale_crop)
+            if upscaled_image.mode != second_upscale.mode:
+                second_upscale = second_upscale.convert(upscaled_image.mode)
             upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
 
             pp.info["Postprocess upscaler 2"] = upscaler2.name
@@ -145,5 +191,5 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0,
         upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
         assert upscaler1, f'could not find upscaler named {upscaler_name}'
 
-        pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+        pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, 0, False)
         pp.info["Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 6d3e42c06..6a42a04d9 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -11,7 +11,7 @@
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_schedulers, errors
 from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, state
 import modules.shared as shared
@@ -45,7 +45,7 @@ def apply_prompt(p, x, xs):
 def apply_order(p, x, xs):
     token_order = []
 
-    # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
+    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
     for token in x:
         token_order.append((p.prompt.find(token), token))
 
@@ -95,33 +95,38 @@ def confirm_checkpoints_or_none(p, xs):
             raise RuntimeError(f"Unknown checkpoint: {x}")
 
 
-def apply_clip_skip(p, x, xs):
-    opts.data["CLIP_stop_at_last_layers"] = x
+def confirm_range(min_val, max_val, axis_label):
+    """Generates a AxisOption.confirm() function that checks all values are within the specified range."""
+    def confirm_range_fun(p, xs):
+        for x in xs:
+            if not (max_val >= x >= min_val):
+                raise ValueError(f'{axis_label} value "{x}" out of range [{min_val}, {max_val}]')
+
+    return confirm_range_fun
 
 
-def apply_upscale_latent_space(p, x, xs):
-    if x.lower().strip() != '0':
-        opts.data["use_scale_latent_for_hires_fix"] = True
-    else:
-        opts.data["use_scale_latent_for_hires_fix"] = False
+
+def apply_size(p, x: str, xs) -> None:
+    try:
+        width, _, height = x.partition('x')
+        width = int(width.strip())
+        height = int(height.strip())
+        p.width = width
+        p.height = height
+    except ValueError:
+        print(f"Invalid size in XYZ plot: {x}")
 
 
 def find_vae(name: str):
-    if name.lower() in ['auto', 'automatic']:
-        return modules.sd_vae.unspecified
-    if name.lower() == 'none':
-        return None
-    else:
-        choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
-        if len(choices) == 0:
-            print(f"No VAE found for {name}; using automatic")
-            return modules.sd_vae.unspecified
-        else:
-            return modules.sd_vae.vae_dict[choices[0]]
+    if (name := name.strip().lower()) in ('auto', 'automatic'):
+        return 'Automatic'
+    elif name == 'none':
+        return 'None'
+    return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), print(f'No VAE found for {name}; using Automatic') or 'Automatic')
 
 
 def apply_vae(p, x, xs):
-    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
+    p.override_settings['sd_vae'] = find_vae(x)
 
 
 def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@@ -129,7 +134,7 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
 
 
 def apply_uni_pc_order(p, x, xs):
-    opts.data["uni_pc_order"] = min(x, p.steps - 1)
+    p.override_settings['uni_pc_order'] = min(x, p.steps - 1)
 
 
 def apply_face_restore(p, opt, x):
@@ -151,12 +156,14 @@ def fun(p, x, xs):
         if boolean:
             x = True if x.lower() == "true" else False
         p.override_settings[field] = x
+
     return fun
 
 
 def boolean_choice(reverse: bool = False):
     def choice():
         return ["False", "True"] if reverse else ["True", "False"]
+
     return choice
 
 
@@ -201,7 +208,7 @@ def list_to_csv_string(data_list):
 
 
 def csv_string_to_list_strip(data_str):
-    return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str)))))
+    return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str), skipinitialspace=True))))
 
 
 class AxisOption:
@@ -248,18 +255,20 @@ def __init__(self, *args, **kwargs):
     AxisOption("Sigma min", float, apply_field("s_tmin")),
     AxisOption("Sigma max", float, apply_field("s_tmax")),
     AxisOption("Sigma noise", float, apply_field("s_noise")),
-    AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
+    AxisOption("Schedule type", str, apply_field("scheduler"), choices=lambda: [x.label for x in sd_schedulers.schedulers]),
     AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
     AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
     AxisOption("Schedule rho", float, apply_override("rho")),
+    AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")),
+    AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")),
     AxisOption("Eta", float, apply_field("eta")),
-    AxisOption("Clip skip", int, apply_clip_skip),
+    AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
     AxisOption("Denoising", float, apply_field("denoising_strength")),
     AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
     AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
     AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
     AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
-    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
+    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
     AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
     AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
@@ -271,6 +280,7 @@ def __init__(self, *args, **kwargs):
     AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
     AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
     AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
+    AxisOption("Size", str, apply_size),
 ]
@@ -366,16 +376,17 @@ def index(ix, iy, iz):
             end_index = start_index + len(xs) * len(ys)
             grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
             if draw_legend:
-                grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
+                grid_max_w, grid_max_h = map(max, zip(*(img.size for img in processed_result.images[start_index:end_index])))
+                grid = images.draw_grid_annotations(grid, grid_max_w, grid_max_h, hor_texts, ver_texts, margin_size)
             processed_result.images.insert(i, grid)
             processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
             processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
             processed_result.infotexts.insert(i, processed_result.infotexts[start_index])
 
-    sub_grid_size = processed_result.images[0].size
     z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
+    z_sub_grid_max_w, z_sub_grid_max_h = map(max, zip(*(img.size for img in processed_result.images[:z_count])))
     if draw_legend:
-        z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
+        z_grid = images.draw_grid_annotations(z_grid, z_sub_grid_max_w, z_sub_grid_max_h, title_texts, [[images.GridAnnotation()]])
     processed_result.images.insert(0, z_grid)
     # TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
     # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
@@ -387,18 +398,12 @@ def index(ix, iy, iz):
 
 class SharedSettingsStackHelper(object):
     def __enter__(self):
-        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
-        self.vae = opts.sd_vae
-        self.uni_pc_order = opts.uni_pc_order
+        pass
 
     def __exit__(self, exc_type, exc_value, tb):
-        opts.data["sd_vae"] = self.vae
-        opts.data["uni_pc_order"] = self.uni_pc_order
         modules.sd_models.reload_model_weights()
         modules.sd_vae.reload_vae_weights()
 
-        opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
-
 
 re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
 re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
@@ -560,7 +565,7 @@ def process_axis(opt, vals, vals_dropdown):
                 mc = re_range_count.fullmatch(val)
                 if m is not None:
                     start = int(m.group(1))
-                    end = int(m.group(2))+1
+                    end = int(m.group(2)) + 1
                     step = int(m.group(3)) if m.group(3) is not None else 1
 
                     valslist_ext += list(range(start, end, step))
@@ -713,11 +718,11 @@ def cell(x, y, z, ix, iy, iz):
                 ydim = len(ys) if vary_seeds_y else 1
 
                 if vary_seeds_x:
-                    pc.seed += ix
+                    pc.seed += ix
                 if vary_seeds_y:
-                    pc.seed += iy * xdim
+                    pc.seed += iy * xdim
                 if vary_seeds_z:
-                    pc.seed += iz * xdim * ydim
+                    pc.seed += iz * xdim * ydim
 
                 try:
                     res = process_images(pc)
@@ -785,18 +790,18 @@ def cell(x, y, z, ix, iy, iz):
         z_count = len(zs)
 
         # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
-        processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]
+        processed.infotexts[:1 + z_count] = grid_infotext[:1 + z_count]
 
         if not include_lone_images:
             # Don't need sub-images anymore, drop from list:
-            processed.images = processed.images[:z_count+1]
+            processed.images = processed.images[:z_count + 1]
 
         if opts.grid_save:
             # Auto-save main and sub-grids:
             grid_count = z_count + 1 if z_count > 1 else 1
             for g in range(grid_count):
                 # TODO: See previous comment about intentional data misalignment.
-                adj_g = g-1 if g > 0 else g
+                adj_g = g - 1 if g > 0 else g
                 images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
                 if not include_sub_grids:  # if not include_sub_grids then skip saving after the first grid
                     break
diff --git a/style.css b/style.css
index 8ce78ff0c..cd1095526 100644
--- a/style.css
+++ b/style.css
@@ -2,14 +2,6 @@
 
 @import url('webui-assets/css/sourcesanspro.css');
 
-
-/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */
-
-div.gradio-image button[aria-label="Edit"] {
-    display: none;
-}
-
-
 /* general gradio fixes */
 
 :root, .dark{
@@ -137,6 +129,10 @@ div.gradio-html.min{
     background: var(--input-background-fill);
 }
 
+.gradio-gallery > button.preview{
+    width: 100%;
+}
+
 .gradio-container .prose a, .gradio-container .prose a:visited{
     color: unset;
     text-decoration: none;
@@ -147,6 +143,15 @@ a{
     cursor: pointer;
 }
 
+.upload-container {
+    width: 100%;
+    max-width: 100%;
+}
+
+.layer-wrap > ul {
+    background: var(--background-fill-primary) !important;
+}
+
 /* gradio 3.39 puts a lot of overflow: hidden all over the place for an unknown reason. */
 div.gradio-container, .block.gradio-textbox, div.gradio-group, div.gradio-dropdown{
     overflow: visible !important;
@@ -279,7 +284,7 @@ input[type="checkbox"].input-accordion-checkbox{
     display: inline-block;
 }
 
-.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr {
+.html-log .performance p.time, .performance p.vram, .performance p.profile, .performance p.time abbr, .performance p.vram abbr {
     margin-bottom: 0;
     color: var(--block-title-text-color);
 }
@@ -291,6 +296,10 @@ input[type="checkbox"].input-accordion-checkbox{
     margin-left: auto;
 }
 
+.html-log .performance p.profile {
+    margin-left: 0.5em;
+}
+
 .html-log .performance .measurement{
     color: var(--body-text-color);
     font-weight: bold;
@@ -387,14 +396,7 @@ div#extras_scale_to_tab div.form{
     flex-direction: row;
 }
 
-#img2img_sketch, #img2maskimg, #inpaint_sketch {
-    overflow: overlay !important;
-    resize: auto;
-    background: var(--panel-background-fill);
-    z-index: 5;
-}
-
-.image-buttons > .form{
+.image-buttons{
     justify-content: center;
 }
 
@@ -528,6 +530,10 @@ table.popup-table .link{
     opacity: 0.75;
 }
 
+.settings-comment .info ol{
+    margin: 0.4em 0 0.8em 1em;
+}
+
 #sysinfo_download a.sysinfo_big_link{
     font-size: 24pt;
 }
@@ -776,9 +782,9 @@ table.popup-table .link{
     position:absolute;
     display:block;
     padding:0px 0;
-    border:2px solid #a55000;
+    border:2px solid var(--primary-800);
     border-radius:8px;
-    box-shadow:1px 1px 2px #CE6400;
+    box-shadow:1px 1px 2px var(--primary-500);
     width: 200px;
 }
 
@@ -795,7 +801,7 @@ table.popup-table .link{
 }
 
 .context-menu-items a:hover{
-    background: #a55000;
+    background: var(--primary-700);
 }
 
 
@@ -803,6 +809,8 @@ table.popup-table .link{
 
 #tab_extensions table{
     border-collapse: collapse;
+    overflow-x: auto;
+    display: block;
 }
 
 #tab_extensions table td, #tab_extensions table th{
@@ -850,6 +858,10 @@ table.popup-table .link{
     display: inline-block;
 }
 
+.compact-checkbox-group div label {
+    padding: 0.1em 0.3em !important;
+}
+
 /* extensions tab table row hover highlight */
 
 #extensions tr:hover td,
@@ -1084,9 +1096,9 @@ footer {
     height:100%;
 }
 
-div.block.gradio-box.edit-user-metadata {
+.edit-user-metadata {
     width: 56em;
-    background: var(--body-background-fill);
+    background: var(--body-background-fill) !important;
     padding: 2em !important;
 }
 
@@ -1120,16 +1132,12 @@ div.block.gradio-box.edit-user-metadata {
     margin-top: 1.5em;
 }
 
-div.block.gradio-box.popup-dialog, .popup-dialog {
+.popup-dialog {
     width: 56em;
-    background: var(--body-background-fill);
+    background: var(--body-background-fill) !important;
     padding: 2em !important;
 }
 
-div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
-    margin-top: 1em;
-}
-
 div.block.input-accordion{
 
 }
@@ -1205,12 +1213,24 @@ body.resizing .resize-handle {
     overflow: hidden;
 }
 
-.extra-network-pane .extra-network-pane-content {
+.extra-network-pane .extra-network-pane-content-dirs {
+    display: flex;
+    flex: 1;
+    flex-direction: column;
+    overflow: hidden;
+}
+
+.extra-network-pane .extra-network-pane-content-tree {
     display: flex;
     flex: 1;
     overflow: hidden;
 }
 
+.extra-network-dirs-hidden .extra-network-dirs{ display: none; }
+.extra-network-dirs-hidden .extra-network-tree{ display: none; }
+.extra-network-dirs-hidden .resize-handle { display: none; }
+.extra-network-dirs-hidden .resize-handle-row { display: flex !important; }
+
 .extra-network-pane .extra-network-tree {
     flex: 1;
     font-size: 1rem;
@@ -1260,7 +1280,7 @@ body.resizing .resize-handle {
 
 .extra-network-control {
     position: relative;
-    display: grid;
+    display: flex;
     width: 100%;
     padding: 0 !important;
     margin-top: 0 !important;
@@ -1277,6 +1297,12 @@ body.resizing .resize-handle {
     align-items: start;
 }
 
+.extra-network-control small{
+    color: var(--input-placeholder-color);
+    line-height: 2.2rem;
+    margin: 0 0.5rem 0 0.75rem;
+}
+
 .extra-network-tree .tree-list--tree {}
 
 /* Remove auto indentation from tree. Will be overridden later. */
@@ -1424,6 +1450,12 @@ body.resizing .resize-handle {
     line-height: 1rem;
 }
 
+
+.extra-network-control .extra-network-control--search .extra-network-control--search-text::placeholder {
+    color: var(--input-placeholder-color);
+}
+
+
 /* clear button (x on right side) styling */
 .extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button {
     -webkit-appearance: none;
@@ -1456,19 +1488,19 @@ body.resizing .resize-handle {
     background-color: var(--input-placeholder-color);
 }
 
-.extra-network-control .extra-network-control--sort[data-sortmode="path"] .extra-network-control--sort-icon {
+.extra-network-control .extra-network-control--sort[data-sortkey="default"] .extra-network-control--sort-icon {
     mask-image: url('data:image/svg+xml,');
 }
 
-.extra-network-control .extra-network-control--sort[data-sortmode="name"] .extra-network-control--sort-icon {
+.extra-network-control .extra-network-control--sort[data-sortkey="name"] .extra-network-control--sort-icon {
     mask-image: url('data:image/svg+xml,');
 }
 
-.extra-network-control .extra-network-control--sort[data-sortmode="date_created"] .extra-network-control--sort-icon {
+.extra-network-control .extra-network-control--sort[data-sortkey="date_created"] .extra-network-control--sort-icon {
     mask-image: url('data:image/svg+xml,');
 }
 
-.extra-network-control .extra-network-control--sort[data-sortmode="date_modified"] .extra-network-control--sort-icon {
+.extra-network-control .extra-network-control--sort[data-sortkey="date_modified"] .extra-network-control--sort-icon {
     mask-image: url('data:image/svg+xml,');
 }
 
@@ -1518,13 +1550,18 @@ body.resizing .resize-handle {
 }
 
 .extra-network-control .extra-network-control--enabled {
-    background-color: rgba(0, 0, 0, 0.15);
+    background-color: rgba(0, 0, 0, 0.1);
+    border-radius: 0.25rem;
 }
 
 .dark .extra-network-control .extra-network-control--enabled {
     background-color: rgba(255, 255, 255, 0.15);
 }
 
+.extra-network-control .extra-network-control--enabled .extra-network-control--icon{
+    background-color: var(--button-secondary-text-color);
+}
+
 /* ==== REFRESH ICON ACTIONS ==== */
 .extra-network-control .extra-network-control--refresh {
     padding: 0.25rem;
@@ -1615,9 +1652,10 @@ body.resizing .resize-handle {
     display: inline-flex;
     visibility: hidden;
     color: var(--button-secondary-text-color);
-
+    width: 0;
 }
 
 .extra-network-tree .tree-list-content:hover .button-row {
     visibility: visible;
+    width: auto;
 }
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index db7e8b1a0..00a36e177 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -4,14 +4,14 @@
 # Please modify webui-user.sh to change these instead of this file #
 ####################################################################
 
-if [[ -x "$(command -v python3.10)" ]]
-then
-    python_cmd="python3.10"
-fi
-
 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0"
 export PYTORCH_ENABLE_MPS_FALLBACK=1
 
+if [[ "$(sysctl -n machdep.cpu.brand_string)" =~ ^.*"Intel".*$ ]]; then
+    export TORCH_COMMAND="pip install torch==2.1.2 torchvision==0.16.2"
+else
+    export TORCH_COMMAND="pip install torch==2.3.1 torchvision==0.18.1"
+fi
+
 ####################################################################
diff --git a/webui.bat b/webui.bat
index e2c9079d2..7b162ce27 100644
--- a/webui.bat
+++ b/webui.bat
@@ -37,12 +37,18 @@ if %ERRORLEVEL% == 0 goto :activate_venv
 for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
 echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
 %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :activate_venv
+if %ERRORLEVEL% == 0 goto :upgrade_pip
 echo Unable to create venv in directory "%VENV_DIR%"
 goto :show_stdout_stderr
 
+:upgrade_pip
+"%VENV_DIR%\Scripts\Python.exe" -m pip install --upgrade pip
+if %ERRORLEVEL% == 0 goto :activate_venv
+echo Warning: Failed to upgrade PIP version
+
 :activate_venv
 set PYTHON="%VENV_DIR%\Scripts\Python.exe"
+call "%VENV_DIR%\Scripts\activate.bat"
 echo venv %PYTHON%
 
 :skip_venv
diff --git a/webui.py b/webui.py
index 022330b0a..921edce9f 100644
--- a/webui.py
+++ b/webui.py
@@ -11,6 +11,8 @@
 
 from modules_forge import main_thread
 
+from modules_forge.forge_canvas.canvas import canvas_js_root_path
+
 startup_timer = timer.startup_timer
 startup_timer.record("launcher")
 
@@ -92,7 +94,7 @@ def webui_worker():
             auth=gradio_auth_creds,
             inbrowser=auto_launch_browser,
             prevent_thread_lock=True,
-            allowed_paths=cmd_opts.gradio_allowed_path,
+            allowed_paths=cmd_opts.gradio_allowed_path + [canvas_js_root_path],
             app_kwargs={
                 "docs_url": "/docs",
                 "redoc_url": "/redoc",
diff --git a/webui.sh b/webui.sh
index f116376f7..89dae163a 100755
--- a/webui.sh
+++ b/webui.sh
@@ -44,7 +44,11 @@ fi
 # python3 executable
 if [[ -z "${python_cmd}" ]]
 then
-    python_cmd="python3"
+    python_cmd="python3.10"
+fi
+if [[ ! -x "$(command -v "${python_cmd}")" ]]
+then
+    python_cmd="python3"
 fi
 
 # git executable
@@ -113,13 +117,13 @@ then
     exit 1
 fi
 
-if [[ -d .git ]]
+if [[ -d "$SCRIPT_DIR/.git" ]]
 then
     printf "\n%s\n" "${delimiter}"
     printf "Repo already cloned, using it as install directory"
     printf "\n%s\n" "${delimiter}"
-    install_dir="${PWD}/../"
-    clone_dir="${PWD##*/}"
+    install_dir="${SCRIPT_DIR}/../"
+    clone_dir="${SCRIPT_DIR##*/}"
 fi
 
 # Check prerequisites
@@ -129,13 +133,19 @@ case "$gpu_info" in
         export HSA_OVERRIDE_GFX_VERSION=10.3.0
         if [[ -z "${TORCH_COMMAND}" ]]
         then
-            pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')"
-            if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]]
+            pyv="$(${python_cmd} -c 'import sys; print(f"{sys.version_info[0]}.{sys.version_info[1]:02d}")')"
+            # Using an old nightly compiled against rocm 5.2 for Navi1, see https://github.com/pytorch/pytorch/issues/106728#issuecomment-1749511711
+            if [[ $pyv == "3.8" ]]
+            then
+                export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp38-cp38-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp38-cp38-linux_x86_64.whl"
+            elif [[ $pyv == "3.9" ]]
+            then
+                export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp39-cp39-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp39-cp39-linux_x86_64.whl"
+            elif [[ $pyv == "3.10" ]]
             then
-                # Navi users will still use torch 1.13 because 2.0 does not seem to work.
-                export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
+                export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl"
             else
-                printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
+                printf "\e[1m\e[31mERROR: RX 5000 series GPUs python version must be between 3.8 and 3.10, aborting...\e[0m"
                 exit 1
             fi
         fi
@@ -143,7 +153,7 @@ case "$gpu_info" in
     *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
     ;;
     *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
-        export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
+        export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
     ;;
     *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
         printf "\n%s\n" "${delimiter}"
@@ -157,11 +167,10 @@ if ! echo "$gpu_info" | grep -q "NVIDIA"; then
     if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
     then
-        export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
-    elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
+        export TORCH_COMMAND="pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7"
+    elif npu-smi info 2>/dev/null
     then
-        export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
-
+        export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu==2.1.0"
     fi
 fi
 
@@ -205,12 +214,15 @@ then
     if [[ ! -d "${venv_dir}" ]]
     then
         "${python_cmd}" -m venv "${venv_dir}"
+        "${venv_dir}"/bin/python -m pip install --upgrade pip
         first_launch=1
     fi
     # shellcheck source=/dev/null
     if [[ -f "${venv_dir}"/bin/activate ]]
     then
        source "${venv_dir}"/bin/activate
+        # ensure use of python from venv
+        python_cmd="${venv_dir}"/bin/python
     else
         printf "\n%s\n" "${delimiter}"
         printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
@@ -238,7 +250,7 @@ prepare_tcmalloc() {
     for lib in "${TCMALLOC_LIBS[@]}"
     do
        # Determine which type of tcmalloc library the library supports
-        TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)"
+        TCMALLOC="$(PATH=/sbin:/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)"
         TC_INFO=(${TCMALLOC//=>/})
         if [[ ! -z "${TC_INFO}" ]]; then
             echo "Check TCMalloc: ${TC_INFO}"