From c90e9965c7b6b4d90bb3d63e3c58352309228e5c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 27 Jun 2023 11:49:29 -0400 Subject: [PATCH] add train.py requirements check --- cli/train.py | 16 +++++++++++++++- pyproject.toml | 10 ++++++---- requirements.txt | 2 +- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/cli/train.py b/cli/train.py index 3896e8af5..c009f0fc8 100755 --- a/cli/train.py +++ b/cli/train.py @@ -372,15 +372,29 @@ def process_inputs(): process.unload() +def check_versions(): + log.info('checking accelerate') + import accelerate + if accelerate.__version__ != '0.19.0': + log.error(f'invalid accelerate version: required=0.19.0 found={accelerate.__version__}') + exit(1) + log.info('checking diffusers') + import diffusers + if diffusers.__version__ != '0.10.2': + log.error(f'invalid diffusers version: required=0.10.2 found={diffusers.__version__}') + exit(1) + + if __name__ == '__main__': log.info('SD.Next train script') parse_args() + setup_logging() + check_versions() sdapi.sd_url = args.server if args.user is not None: sdapi.sd_username = args.user if args.password is not None: sdapi.sd_password = args.password - setup_logging() prepare_server() verify_args() prepare_options() diff --git a/pyproject.toml b/pyproject.toml index 36978b354..ac09272ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,17 +38,19 @@ exclude = [ ] ignore = [ "A003", # Class attirbute shadowing builtin - "C408", # Rewrite as a literal "C901", # Function is too complex "E501", # Line too long "E731", # Do not assign a `lambda` expression, use a `def` "I001", # Import block is un-sorted or un-formatted - "W605", # invalid escape sequence, messes with some docstrings + "W605", # Invalid escape sequence, messes with some docstrings + "B905", # Without explicit scrict + "C408", # Rewrite as a literal "E402", # Module level import not at top of file "F401", # Imported but unused - "B905", # Without explicit scrict - "RUF005", # Consider concatenation "ISC003", # Implicit string concatenation + "RUF005", # Consider concatenation + "RUF012", # Mutable class attributes + "RUF013", # Implict optional ] [tool.ruff.flake8-bugbear] diff --git a/requirements.txt b/requirements.txt index 7903c0e06..8067314f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,7 +47,7 @@ requests==2.31.0 tqdm==4.65.0 accelerate==0.20.3 opencv-python==4.7.0.72 -diffusers==0.16.1 +diffusers==0.17.1 einops==0.4.1 gradio==3.32.0 numexpr==2.8.4