Various tweaks around Flux onboarding (#451)
* fix: prevent workers without flux support from picking up flux jobs

* feat: adjusted TTL formula to be algorithmic

* feat: support betas

* Added support for extra_slow_workers

* Added support for workers limiting step count to the model's expectation
db0 committed Sep 13, 2024
1 parent 1223851 commit c8a455b
Showing 22 changed files with 381 additions and 76 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,14 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.43.0

* Adjusted TTL formula to be algorithmic
* Prevents workers without Flux support from picking up Flux jobs
* Adds `extra_slow_workers` bool for image gen async
* Adds `extra_slow_worker` bool for worker pop
* Adds `limit_max_steps` for worker pop

# 4.42.0

* Adds support for the Flux family of models
6 changes: 6 additions & 0 deletions horde/apis/models/kobold_v2.py
@@ -348,6 +348,12 @@ def __init__(self, api):
"The request will include the details of the job as well as the request ID."
),
),
"extra_slow_workers": fields.Boolean(
default=False,
description=(
"When True, allows very slower workers to pick up this request. " "Use this when you don't mind waiting a lot."
),
),
},
)
self.response_model_contrib_details = api.inherit(
24 changes: 24 additions & 0 deletions horde/apis/models/stable_v2.py
@@ -152,6 +152,14 @@ def __init__(self):
help="If True, this worker will pick up requests requesting LoRas.",
location="json",
)
self.job_pop_parser.add_argument(
"limit_max_steps",
type=bool,
required=False,
default=False,
help="If True, This worker will not pick up jobs with more steps than the average allowed for that model.",
location="json",
)
self.job_submit_parser.add_argument(
"seed",
type=int,
@@ -451,6 +459,9 @@ def __init__(self, api):
"max_pixels": fields.Integer(
description="How many waiting requests were skipped because they demanded a higher size than this worker provides.",
),
"step_count": fields.Integer(
description="How many waiting requests were skipped because they demanded a higher step count that the worker wants.",
),
"unsafe_ip": fields.Integer(
description="How many waiting requests were skipped because they came from an unsafe IP.",
),
@@ -544,6 +555,13 @@ def __init__(self, api):
default=True,
description="If True, this worker will pick up requests requesting LoRas.",
),
"limit_max_steps": fields.Boolean(
default=False,
description=(
"If True, this worker will not pick up jobs with more steps than the average allowed for that model."
" This is for use by workers which might run into issues when doing too many steps."
),
),
},
)
self.input_model_job_submit = api.inherit(
@@ -591,6 +609,12 @@ def __init__(self, api):
default=True,
description="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.",
),
"extra_slow_workers": fields.Boolean(
default=False,
description=(
"When True, allows very slower workers to pick up this request. " "Use this when you don't mind waiting a lot."
),
),
"censor_nsfw": fields.Boolean(
default=False,
description="If the request is SFW, and the worker accidentally generates NSFW, it will send back a censored image.",
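For context on how a client opts in: the new extra_slow_workers field rides along the existing slow_workers toggle in the async generation payload. The sketch below is illustrative only; the /v2/generate/async endpoint and the anonymous apikey follow the public AI Horde API, but the prompt, model name, and parameter values are invented.

import requests

# Hypothetical async image request opting in to extra slow workers.
payload = {
    "prompt": "a lighthouse at dawn",             # invented example
    "models": ["Flux.1-Schnell fp8 (Compact)"],   # assumed model name
    "params": {"steps": 20},
    "slow_workers": True,        # default; ordinary slow workers allowed
    "extra_slow_workers": True,  # new: accept very slow workers, expect long waits
}
r = requests.post(
    "https://aihorde.net/api/v2/generate/async",
    json=payload,
    headers={"apikey": "0000000000"},  # anonymous key, for illustration
)
print(r.json())  # e.g. {"id": "...", "kudos": ...}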
22 changes: 22 additions & 0 deletions horde/apis/models/v2.py
@@ -101,6 +101,14 @@ def __init__(self):
help="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.",
location="json",
)
self.generate_parser.add_argument(
"extra_slow_workers",
type=bool,
default=False,
required=False,
help="When True, allows very slower workers to pick up this request. Use this when you don't mind waiting a lot.",
location="json",
)
self.generate_parser.add_argument(
"dry_run",
type=bool,
@@ -204,6 +212,13 @@ def __init__(self):
help="How many jobvs to pop at the same time",
location="json",
)
self.job_pop_parser.add_argument(
"extra_slow_worker",
type=bool,
default=False,
required=False,
location="json",
)

self.job_submit_parser = reqparse.RequestParser()
self.job_submit_parser.add_argument(
@@ -537,6 +552,13 @@ def __init__(self, api):
min=1,
max=20,
),
"extra_slow_worker": fields.Boolean(
default=False,
description=(
"If True, marks the worker as very slow. You should only use this if your mps/s is lower than 0.1."
" Extra slow workers are excluded from normal requests but users can opt in to use them."
),
),
},
)
self.response_model_worker_details = api.inherit(
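On the worker side, the two new pop flags announce the worker's speed class and step-count policy. A minimal sketch, again assuming the public /v2/generate/pop endpoint; the worker name and model are invented:

import requests

# Hypothetical worker pop request using the new flags from this commit.
payload = {
    "name": "my-extra-slow-worker",               # invented worker name
    "models": ["Flux.1-Schnell fp8 (Compact)"],   # assumed model name
    "max_pixels": 1048576,
    "amount": 1,
    "extra_slow_worker": True,  # marks this worker as very slow (mps/s < 0.1)
    "limit_max_steps": True,    # skip jobs above the model's average step count
}
r = requests.post(
    "https://aihorde.net/api/v2/generate/pop",
    json=payload,
    headers={"apikey": "0000000000"},  # use the worker owner's key in practice
)
# With no matching job, the response reports why waiting requests were
# skipped, including the new step_count counter added in this commit.
print(r.json().get("skipped", {}))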
3 changes: 1 addition & 2 deletions horde/apis/v2/base.py
@@ -456,7 +456,6 @@ def post(self):
# as they're typically countermeasures to raids
if skipped_reason != "secret":
self.skipped[skipped_reason] = self.skipped.get(skipped_reason, 0) + 1
# logger.warning(datetime.utcnow())

continue
# There is a chance that by the time we finished all the checks, another worker picked up the WP.
@@ -477,7 +476,7 @@
# We report maintenance exception only if we couldn't find any jobs
if self.worker.maintenance:
raise e.WorkerMaintenance(self.worker.maintenance_msg)
# logger.warning(datetime.utcnow())
# logger.debug(self.skipped)
return {"id": None, "ids": [], "skipped": self.skipped}, 200

def get_sorted_wp(self, priority_user_ids=None):
7 changes: 7 additions & 0 deletions horde/apis/v2/stable.py
@@ -300,6 +300,7 @@ def initiate_waiting_prompt(self):
validated_backends=self.args.validated_backends,
worker_blacklist=self.args.worker_blacklist,
slow_workers=self.args.slow_workers,
extra_slow_workers=self.args.extra_slow_workers,
source_processing=self.args.source_processing,
ipaddr=self.user_ip,
safe_ip=self.safe_ip,
@@ -599,6 +600,10 @@ def post(self):
db_skipped["kudos"] = post_ret["skipped"]["kudos"]
if "blacklist" in post_ret.get("skipped", {}):
db_skipped["blacklist"] = post_ret["skipped"]["blacklist"]
if "step_count" in post_ret.get("skipped", {}):
db_skipped["step_count"] = post_ret["skipped"]["step_count"]
if "bridge_version" in post_ret.get("skipped", {}):
db_skipped["bridge_version"] = db_skipped.get("bridge_version", 0) + post_ret["skipped"]["bridge_version"]
post_ret["skipped"] = db_skipped
# logger.debug(post_ret)
return post_ret, retcode
@@ -621,6 +626,8 @@ def check_in(self):
allow_controlnet=self.args.allow_controlnet,
allow_sdxl_controlnet=self.args.allow_sdxl_controlnet,
allow_lora=self.args.allow_lora,
extra_slow_worker=self.args.extra_slow_worker,
limit_max_steps=self.args.limit_max_steps,
priority_usernames=self.priority_usernames,
)

3 changes: 3 additions & 0 deletions horde/bridge_reference.py
@@ -9,6 +9,7 @@

BRIDGE_CAPABILITIES = {
"AI Horde Worker reGen": {
9: {"flux"},
8: {"layer_diffuse"},
7: {"qr_code", "extra_texts", "workflow"},
6: {"stable_cascade_2pass"},
@@ -185,6 +186,7 @@ def parse_bridge_agent(bridge_agent):
@logger.catch(reraise=True)
def check_bridge_capability(capability, bridge_agent):
bridge_name, bridge_version = parse_bridge_agent(bridge_agent)
# logger.debug([bridge_name, bridge_version])
if bridge_name not in BRIDGE_CAPABILITIES:
return False
total_capabilities = set()
@@ -194,6 +196,7 @@ def check_bridge_capability(capability, bridge_agent):
if checked_semver.compare(bridge_version) <= 0:
total_capabilities.update(BRIDGE_CAPABILITIES[bridge_name][version])
# logger.debug([total_capabilities, capability, capability in total_capabilities])
# logger.debug([bridge_name, BRIDGE_CAPABILITIES[bridge_name]])
return capability in total_capabilities


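The capability table is cumulative: a bridge receives every capability set keyed at or below its version, which is why registering flux under version 9 keeps reGen 8 and older from being handed Flux jobs. A self-contained sketch of that lookup, simplified to integer versions where the real code compares semver strings:

# Simplified model of the cumulative capability lookup above.
BRIDGE_CAPABILITIES = {
    "AI Horde Worker reGen": {
        9: {"flux"},
        8: {"layer_diffuse"},
        7: {"qr_code", "extra_texts", "workflow"},
        6: {"stable_cascade_2pass"},
    },
}

def check_capability(capability: str, bridge_name: str, version: int) -> bool:
    versions = BRIDGE_CAPABILITIES.get(bridge_name)
    if versions is None:
        return False
    total_capabilities = set()
    for threshold, caps in versions.items():
        if threshold <= version:  # cumulative: every older tier still applies
            total_capabilities |= caps
    return capability in total_capabilities

assert check_capability("flux", "AI Horde Worker reGen", 9)
assert not check_capability("flux", "AI Horde Worker reGen", 8)  # no Flux jobs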
13 changes: 11 additions & 2 deletions horde/classes/base/processing_generation.py
@@ -44,6 +44,7 @@ class ProcessingGeneration(db.Model):
nullable=False,
server_default=expression.literal(False),
)
job_ttl = db.Column(db.Integer, default=150, nullable=False, index=True)

wp_id = db.Column(
uuid_column_type(),
@@ -80,6 +81,7 @@ def __init__(self, *args, **kwargs):
self.model = matching_models[0]
else:
self.model = kwargs["model"]
self.set_job_ttl()
db.session.commit()

def set_generation(self, generation, things_per_sec, **kwargs):
@@ -163,10 +165,10 @@ def is_completed(self):
def is_faulted(self):
return self.faulted

def is_stale(self, ttl):
def is_stale(self):
if self.is_completed() or self.is_faulted():
return False
return (datetime.utcnow() - self.start_time).total_seconds() > ttl
return (datetime.utcnow() - self.start_time).total_seconds() > self.job_ttl

def delete(self):
db.session.delete(self)
Expand Down Expand Up @@ -224,3 +226,10 @@ def send_webhook(self, kudos):
break
except Exception as err:
logger.debug(f"Exception when sending generation webhook: {err}. Will retry {3-riter-1} more times...")

def set_job_ttl(self):
"""Returns how many seconds each job request should stay waiting before considering it stale and cancelling it
This function should be overriden by the invididual hordes depending on how the calculating ttl
"""
self.job_ttl = 150
db.session.commit()
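The base implementation keeps the old flat 150-second TTL; the point of the refactor is that each horde can now override set_job_ttl per generation (the "algorithmic" TTL from the commit message). The formula below is purely hypothetical, since the real per-horde override is not part of this diff; it only illustrates the kind of scaling such an override might do before assigning self.job_ttl and committing:

def algorithmic_job_ttl(width: int, height: int, steps: int) -> int:
    """Hypothetical TTL formula: scale with megapixelsteps, clamped."""
    megapixelsteps = (width * height * steps) / 1_000_000
    return int(min(max(150, 60 + megapixelsteps * 3), 600))

assert algorithmic_job_ttl(512, 512, 30) == 150    # small jobs keep the floor
assert algorithmic_job_ttl(1024, 1024, 50) == 217  # bigger jobs get more time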
23 changes: 11 additions & 12 deletions horde/classes/base/waiting_prompt.py
@@ -19,7 +19,7 @@
from horde.classes.stable.processing_generation import ImageProcessingGeneration
from horde.flask import SQLITE_MODE, db
from horde.logger import logger
from horde.utils import get_db_uuid, get_expiry_date
from horde.utils import get_db_uuid, get_expiry_date, get_extra_slow_expiry_date

procgen_classes = {
"template": ProcessingGeneration,
@@ -93,6 +93,7 @@ class WaitingPrompt(db.Model):
trusted_workers = db.Column(db.Boolean, default=False, nullable=False, index=True)
validated_backends = db.Column(db.Boolean, default=True, nullable=False, index=True)
slow_workers = db.Column(db.Boolean, default=True, nullable=False, index=True)
extra_slow_workers = db.Column(db.Boolean, default=False, nullable=False, index=True)
worker_blacklist = db.Column(db.Boolean, default=False, nullable=False, index=True)
faulted = db.Column(db.Boolean, default=False, nullable=False, index=True)
active = db.Column(db.Boolean, default=False, nullable=False, index=True)
@@ -105,6 +106,7 @@ class WaitingPrompt(db.Model):
things = db.Column(db.BigInteger, default=0, nullable=False)
total_usage = db.Column(db.Float, default=0, nullable=False)
extra_priority = db.Column(db.Integer, default=0, nullable=False, index=True)
# TODO: Delete. Obsoleted.
job_ttl = db.Column(db.Integer, default=150, nullable=False)
disable_batching = db.Column(db.Boolean, default=False, nullable=False)
webhook = db.Column(db.String(1024))
@@ -204,7 +206,6 @@ def extract_params(self):
self.things = 0
self.total_usage = round(self.things * self.n, 2)
self.prepare_job_payload()
self.set_job_ttl()
db.session.commit()

def prepare_job_payload(self):
@@ -241,7 +242,7 @@ def start_generation(self, worker, amount=1):
self.n -= safe_amount
payload = self.get_job_payload(current_n)
# This does a commit as well
self.refresh()
self.refresh(worker)
procgen_class = procgen_classes[self.wp_type]
gens_list = []
model = None
@@ -457,8 +458,13 @@ def abort_for_maintenance(self):
except Exception as err:
logger.warning(f"Error when aborting WP. Skipping: {err}")

def refresh(self):
self.expiry = get_expiry_date()
def refresh(self, worker=None):
if worker is not None and worker.extra_slow_worker is True:
self.expiry = get_extra_slow_expiry_date()
else:
new_expiry = get_expiry_date()
if self.expiry < new_expiry:
self.expiry = new_expiry
db.session.commit()

def is_stale(self):
@@ -469,13 +475,6 @@ def is_stale(self):
def get_priority(self):
return self.extra_priority

def set_job_ttl(self):
"""Returns how many seconds each job request should stay waiting before considering it stale and cancelling it
This function should be overriden by the invididual hordes depending on how the calculating ttl
"""
self.job_ttl = 150
db.session.commit()

def refresh_worker_cache(self):
worker_ids = [worker.worker_id for worker in self.workers]
worker_string_ids = [str(worker.worker_id) for worker in self.workers]
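get_extra_slow_expiry_date is a new helper in horde/utils.py whose body is not shown in this diff. Note the asymmetry in refresh above: an extra slow worker always gets the longer window, while the normal path only ever extends the expiry (the < guard stops a later refresh from shortening a longer extra-slow window). A plausible shape for the helpers, with invented durations:

from datetime import datetime, timedelta

# Invented durations; the real values live in horde/utils.py.
def get_expiry_date():
    return datetime.utcnow() + timedelta(minutes=20)

def get_extra_slow_expiry_date():
    # A longer window, so very slow workers are not timed out mid-job.
    return datetime.utcnow() + timedelta(minutes=60)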
19 changes: 10 additions & 9 deletions horde/classes/base/worker.py
@@ -121,6 +121,7 @@ class WorkerTemplate(db.Model):
# Used by all workers to record how much they can pick up to generate
# The value of this column is different per worker type
max_power = db.Column(db.Integer, default=20, nullable=False)
extra_slow_worker = db.Column(db.Boolean, default=False, nullable=False, index=True)

paused = db.Column(db.Boolean, default=False, nullable=False)
maintenance = db.Column(db.Boolean, default=False, nullable=False)
@@ -196,7 +197,7 @@ def report_suspicion(self, amount=1, reason=Suspicions.WORKER_PROFANITY, formats
f"Last suspicion log: {reason.name}.\n"
f"Total Suspicion {self.get_suspicion()}",
)
db.session.commit()
db.session.flush()

def get_suspicion_reasons(self):
return set([s.suspicion_id for s in self.suspicions])
@@ -261,10 +262,6 @@ def toggle_paused(self, is_paused_active):

# This should be extended by each worker type
def check_in(self, **kwargs):
# To avoid excessive commits,
# we only record new changes on the worker every 30 seconds
if (datetime.utcnow() - self.last_check_in).total_seconds() < 30 and (datetime.utcnow() - self.created).total_seconds() > 30:
return
self.ipaddr = kwargs.get("ipaddr", None)
self.bridge_agent = sanitize_string(kwargs.get("bridge_agent", "unknown:0:unknown"))
self.threads = kwargs.get("threads", 1)
@@ -275,6 +272,10 @@ def check_in(self, **kwargs):
self.prioritized_users = kwargs.get("prioritized_users", [])
if not kwargs.get("safe_ip", True) and not self.user.trusted:
self.report_suspicion(reason=Suspicions.UNSAFE_IP)
# To avoid excessive commits,
# we only record new uptime on the worker every 30 seconds
if (datetime.utcnow() - self.last_check_in).total_seconds() < 30 and (datetime.utcnow() - self.created).total_seconds() > 30:
return
if not self.is_stale() and not self.paused and not self.maintenance:
self.uptime += (datetime.utcnow() - self.last_check_in).total_seconds()
# Every 10 minutes of uptime gets 100 kudos rewarded
@@ -293,7 +294,6 @@ def check_in(self, **kwargs):
# So that they have to stay up at least 10 mins to get uptime kudos
self.last_reward_uptime = self.uptime
self.last_check_in = datetime.utcnow()
db.session.commit()

def get_human_readable_uptime(self):
if self.uptime < 60:
@@ -511,7 +511,8 @@ def check_in(self, **kwargs):
self.set_models(kwargs.get("models"))
self.nsfw = kwargs.get("nsfw", True)
self.set_blacklist(kwargs.get("blacklist", []))
db.session.commit()
self.extra_slow_worker = kwargs.get("extra_slow_worker", False)
# Commit should happen in the calling extensions

def set_blacklist(self, blacklist):
# We don't allow workers to claim they can serve more than 50 models atm (to prevent abuse)
@@ -527,7 +528,7 @@ def set_blacklist(self, blacklist):
for word in blacklist:
blacklisted_word = WorkerBlackList(worker_id=self.id, word=word[0:15])
db.session.add(blacklisted_word)
db.session.commit()
db.session.flush()

def refresh_model_cache(self):
models_list = [m.model for m in self.models]
@@ -563,7 +564,7 @@ def set_models(self, models):
return
# logger.debug([existing_model_names,models, existing_model_names == models])
db.session.query(WorkerModel).filter_by(worker_id=self.id).delete()
db.session.commit()
db.session.flush()
for model_name in models:
model = WorkerModel(worker_id=self.id, model=model_name)
db.session.add(model)
(Diffs for the remaining 12 changed files were not loaded.)