Skip to content

Commit

Permalink
add train.py requirements check
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic committed Jun 27, 2023
1 parent c80b1eb commit c90e996
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
16 changes: 15 additions & 1 deletion cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,29 @@ def process_inputs():
process.unload()


def check_versions():
log.info('checking accelerate')
import accelerate
if accelerate.__version__ != '0.19.0':
log.error(f'invalid accelerate version: required=0.19.0 found={accelerate.__version__}')
exit(1)
log.info('checking diffusers')
import diffusers
if diffusers.__version__ != '0.10.2':
log.error(f'invalid diffusers version: required=0.10.2 found={diffusers.__version__}')
exit(1)


if __name__ == '__main__':
log.info('SD.Next train script')
parse_args()
setup_logging()
check_versions()
sdapi.sd_url = args.server
if args.user is not None:
sdapi.sd_username = args.user
if args.password is not None:
sdapi.sd_password = args.password
setup_logging()
prepare_server()
verify_args()
prepare_options()
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ exclude = [
]
ignore = [
"A003", # Class attirbute shadowing builtin
"C408", # Rewrite as a literal
"C901", # Function is too complex
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
"I001", # Import block is un-sorted or un-formatted
"W605", # invalid escape sequence, messes with some docstrings
"W605", # Invalid escape sequence, messes with some docstrings
"B905", # Without explicit scrict
"C408", # Rewrite as a literal
"E402", # Module level import not at top of file
"F401", # Imported but unused
"B905", # Without explicit scrict
"RUF005", # Consider concatenation
"ISC003", # Implicit string concatenation
"RUF005", # Consider concatenation
"RUF012", # Mutable class attributes
"RUF013", # Implict optional
]

[tool.ruff.flake8-bugbear]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ requests==2.31.0
tqdm==4.65.0
accelerate==0.20.3
opencv-python==4.7.0.72
diffusers==0.16.1
diffusers==0.17.1
einops==0.4.1
gradio==3.32.0
numexpr==2.8.4
Expand Down

0 comments on commit c90e996

Please sign in to comment.