Skip to content

Commit

Permalink
- Add aysnc support for search to fix timeout issues with search
Browse files Browse the repository at this point in the history
- Add retry logic to search to fix timeout issues with search
  • Loading branch information
regiellis committed Aug 29, 2024
1 parent 8148496 commit 56adf6a
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 136 deletions.
6 changes: 4 additions & 2 deletions civitai_models_manager/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# "groq"
# ]
# ///
import asyncio

from civitai_models_manager.__version__ import __version__
from civitai_models_manager import (
MODELS_DIR,
Expand All @@ -32,7 +34,7 @@
from .modules.list import list_models_cli, local_search_cli
from .modules.download import download_model_cli
from .modules.ai import explain_model_cli
from .modules.search import search_cli
from .modules.search import search_cli_sync
from .modules.remove import remove_models_cli
import typer

Expand Down Expand Up @@ -100,7 +102,7 @@ def search_models_command(
sort: str = "Highest Rated",
period: str = "AllTime",
):
search_cli(
search_cli_sync(
query,
tag,
types,
Expand Down
3 changes: 1 addition & 2 deletions civitai_models_manager/modules/details.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import httpx
import subprocess
import json

from typing import Any, Dict, Optional
import html2text
Expand Down Expand Up @@ -160,7 +159,7 @@ def print_model_details(
console.print(model_table)

if desc:
desc_table = create_table("Description", [("Description", "cyan")])
desc_table = create_table("", [("Description", "cyan")])
desc_table.add_row(h2t.handle(model_details["description"]))
console.print(desc_table)

Expand Down
243 changes: 111 additions & 132 deletions civitai_models_manager/modules/search.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,70 @@
import httpx
import subprocess
import questionary
import asyncio
from questionary import Style
from typing import Any, Dict, List, Optional
from rich.console import Console
from rich.text import Text
from .helpers import create_table, feedback_message
from .utils import clean_text, format_file_size
from tenacity import retry, stop_after_attempt, wait_exponential, RetryError

console = Console(soft_wrap=True)

__all__ = ["search_models", "search_cli"]

custom_style = Style(
[
("qmark", "fg:#ffff00 bold"), # Yellow question mark
("question", "fg:#ffffff bold"), # White bold question text
("answer", "fg:#ffff00 bold"), # Yellow bold answer text
("pointer", "fg:#ffff00 bold"), # Yellow bold pointer
(
"highlighted",
"fg:#ffff00 bold",
), # black text on cyan background for highlighted items
("selected", "fg:#ffff00"), # Yellow for selected items
("separator", "fg:#ffff00"), # Yellow separator
("instruction", "fg:#ffffff"), # White instruction text
("text", "fg:#ffffff"), # White general text
("disabled", "fg:#ffff00 italic"), # Yellow italic for disabled items
]
)


def pagination_menu(
metadata: Dict[str, Any], has_previous: bool, download_function
) -> Optional[str]:
__all__ = ["search_models", "search_cli", "search_cli_sync"]

custom_style = Style([
("qmark", "fg:#ffff00 bold"),
("question", "fg:#ffffff bold"),
("answer", "fg:#ffff00 bold"),
("pointer", "fg:#ffff00 bold"),
("highlighted", "fg:#ffff00 bold"),
("selected", "fg:#ffff00"),
("separator", "fg:#ffff00"),
("instruction", "fg:#ffffff"),
("text", "fg:#ffffff"),
("disabled", "fg:#ffff00 italic"),
])

def pagination_menu(metadata: Dict[str, Any], has_previous: bool, download_function) -> Optional[str]:
choices = []

if has_previous:
choices.append("Previous Page")

if metadata.get("nextPage"):
choices.append("Next Page")
choices.extend(["Download Model", "Exit"])

choices.append("Download Model")
choices.append("Exit")

action = questionary.select(
"What would you like to do?", choices=choices, style=custom_style
).ask()
action = questionary.select("What would you like to do?", choices=choices, style=custom_style).ask()

if action == "Previous Page":
return "prev"
elif action == "Next Page":
return "next"
elif action == "Download Model":
model_id = questionary.text(
"Enter the Model ID you want to download:", style=custom_style
).ask()
model_id = questionary.text("Enter the Model ID you want to download:", style=custom_style).ask()
try:
model_id = int(model_id)
subprocess.run(f"civitai-models download {model_id}", shell=True)
except ValueError:
print("Invalid Model ID. Please enter a valid number.")
return None
elif action == "Exit":
return "exit"

return None


def validate_param(key: str, value: Any, valid_values: List[str]) -> bool:
if value not in valid_values and value is not None:
feedback_message(
f"\"{value}\" is not a valid {key}.\nPlease choose from: {', '.join(valid_values)}",
"error",
)
feedback_message(f"\"{value}\" is not a valid {key}.\nPlease choose from: {', '.join(valid_values)}", "error")
return False
return True

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def make_api_request(client: httpx.AsyncClient, url: str, params: Dict[str, Any]) -> Dict[str, Any]:
response = await client.get(url, params=params, timeout=30)
response.raise_for_status()
return response.json()

def search_models(
query: str = "", CIVITAI_MODELS=None, TYPES=None, **kwargs
) -> Dict[str, Any]:
async def search_models(query: str = "", CIVITAI_MODELS=None, TYPES=None, **kwargs) -> Dict[str, Any]:
allowed_params = {
"tag": None,
"types": "Checkpoint",
Expand All @@ -90,10 +73,7 @@ def search_models(
"period": "AllTime",
"page": 1,
}
params = {
**allowed_params,
**{k: v for k, v in kwargs.items() if k in allowed_params},
}
params = {**allowed_params, **{k: v for k, v in kwargs.items() if k in allowed_params}}

if query:
params["query"] = query
Expand All @@ -107,92 +87,91 @@ def search_models(
if not all(validate_param(*v) for v in validations):
return {}

response = httpx.get(CIVITAI_MODELS, params=params)
return response.json() if response.status_code == 200 else {}


def search_cli(
query: str = "",
tag=None,
types="Checkpoint",
limit=20,
sort="Highest Rated",
period="AllTime",
CIVITAI_MODELS=None,
TYPES=None,
download_function=None,
) -> None:
async with httpx.AsyncClient() as client:
try:
return await make_api_request(client, CIVITAI_MODELS, params)
except RetryError:
feedback_message("Failed to connect to the API after multiple attempts.", "error")
return {}
except httpx.HTTPStatusError as e:
feedback_message(f"HTTP error occurred: {e}", "error")
return {}
except Exception as e:
feedback_message(f"An unexpected error occurred: {e}", "error")
return {}

async def search_cli(query: str = "", tag=None, types="Checkpoint", limit=20, sort="Highest Rated", period="AllTime", CIVITAI_MODELS=None, TYPES=None, download_function=None) -> None:
current_url = CIVITAI_MODELS
has_previous = False
page_history = []

while True:
with console.status("[yellow]Searching for models...", spinner="dots"):
response = httpx.get(
current_url,
params={
"query": query,
"tag": tag,
"types": types,
"limit": limit,
"sort": sort,
"period": period,
},
)
models = response.json() if response.status_code == 200 else {}

if models.get("items") == []:
feedback_message("No models found. Please try again.", "warning")
return

metadata = models.get("metadata", {})
# console.print(metadata)

search_table = create_table(
"",
[
("Model ID", "bright_yellow"),
("Model Name", "white"),
("Model Type", "bright_yellow"),
("Model NSFW", "white"),
("Model Tags", "white"),
],
)

for model in models.get("items", []):
name = Text(clean_text(model["name"]), style="bold", overflow="ellipsis")
tags = Text(", ".join(model["tags"]), style="italic", overflow="ellipsis")
size = Text(
format_file_size(model.get("modelVersions")[0]["files"][0]["sizeKB"]),
style="yellow",
)
nsfw = (
Text("Yes", style="green")
if model["nsfw"]
else Text("No", style="bright_red")
)
search_table.add_row(
str(model["id"]),
f"{name} // [yellow]{size}[/yellow]",
model["type"],
nsfw,
tags,
async with httpx.AsyncClient() as client:
while True:
with console.status("[yellow]Searching for models...", spinner="dots"):
try:
models = await make_api_request(client, current_url, {
"query": query,
"tag": tag,
"types": types,
"limit": limit,
"sort": sort,
"period": period,
})
except Exception as e:
feedback_message(f"Error occurred: {str(e)}", "error")
return

if models.get("items") == []:
feedback_message("No models found. Please try again.", "warning")
return

metadata = models.get("metadata", {})

search_table = create_table(
"",
[
("Model ID", "bright_yellow"),
("Model Name", "white"),
("Model Type", "bright_yellow"),
("Model NSFW", "white"),
("Model Tags", "white"),
],
)

console.print(search_table)

action = pagination_menu(metadata, has_previous, download_function)

if action == "prev":
if page_history:
current_url = page_history.pop()
has_previous = bool(page_history)
elif action == "next":
if metadata.get("nextPage"):
page_history.append(current_url)
current_url = metadata["nextPage"]
has_previous = True
elif action == "exit":
break
else:
continue
for model in models.get("items", []):
name = Text(clean_text(model["name"]), style="bold", overflow="ellipsis")
tags = Text(", ".join(model["tags"]), style="italic", overflow="ellipsis")
size = Text(format_file_size(model.get("modelVersions")[0]["files"][0]["sizeKB"]), style="yellow")
nsfw = Text("Yes", style="green") if model["nsfw"] else Text("No", style="bright_red")
search_table.add_row(
str(model["id"]),
f"{name} // [yellow]{size}[/yellow]",
model["type"],
nsfw,
tags,
)

console.print(search_table)

# Run the synchronous pagination_menu in the default event loop
action = await asyncio.get_event_loop().run_in_executor(None, pagination_menu, metadata, has_previous, download_function)

if action == "prev":
if page_history:
current_url = page_history.pop()
has_previous = bool(page_history)
elif action == "next":
if metadata.get("nextPage"):
page_history.append(current_url)
current_url = metadata["nextPage"]
has_previous = True
elif action == "exit":
break
else:
continue

def search_cli_sync(query: str = "", tag=None, types="Checkpoint", limit=20, sort="Highest Rated", period="AllTime", CIVITAI_MODELS=None, TYPES=None, download_function=None) -> None:
"""
Synchronous wrapper for the asynchronous search_cli function.
"""
asyncio.run(search_cli(query, tag, types, limit, sort, period, CIVITAI_MODELS, TYPES, download_function))

0 comments on commit 56adf6a

Please sign in to comment.