Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for seek by worker name (#243) #245

Merged
merged 1 commit into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions examples/ai_horde_client/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@
)


def all_workers(api_key: str, simple_client: AIHordeAPISimpleClient, filename: str) -> None:
def all_workers(
api_key: str,
simple_client: AIHordeAPISimpleClient,
filename: str,
*,
worker_name: str | None = None,
) -> None:
all_workers_response: AllWorkersDetailsResponse

all_workers_response = simple_client.workers_all_details()
all_workers_response = simple_client.workers_all_details(worker_name=worker_name)

if worker_name is None:
logger.info("Getting details for all workers.")
else:
logger.info(f"Getting details for worker with name: {worker_name}")

if all_workers_response is None:
raise ValueError("No workers returned in the response.")
Expand Down Expand Up @@ -101,6 +112,13 @@ def set_maintenance_mode(
help="The worker ID to get details for.",
)

group.add_argument(
"--worker_name",
"-n",
type=str,
help="The worker name to get details for.",
)

group2 = parser.add_mutually_exclusive_group()
group2.add_argument(
"--maintenance-mode-on",
Expand All @@ -123,7 +141,14 @@ def set_maintenance_mode(

simple_client = AIHordeAPISimpleClient()

if args.all:
if args.worker_name:
all_workers(
api_key=args.apikey,
simple_client=simple_client,
filename=args.filename,
worker_name=args.worker_name,
)
elif args.all:
all_workers(
api_key=args.apikey,
simple_client=simple_client,
Expand Down
9 changes: 7 additions & 2 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,14 +955,18 @@ def text_generate_request_dry_run(

def workers_all_details(
self,
worker_name: str | None = None,
) -> AllWorkersDetailsResponse:
"""Get all the details for all workers.

Returns:
WorkersAllDetailsResponse: The response from the API.
"""
with AIHordeAPIClientSession() as horde_session:
response = horde_session.submit_request(AllWorkersDetailsRequest(), AllWorkersDetailsResponse)
response = horde_session.submit_request(
AllWorkersDetailsRequest(name=worker_name),
AllWorkersDetailsResponse,
)

if isinstance(response, RequestErrorResponse):
raise AIHordeRequestError(response)
Expand Down Expand Up @@ -1643,6 +1647,7 @@ async def text_generate_request_dry_run(

async def workers_all_details(
self,
worker_name: str | None = None,
) -> AllWorkersDetailsResponse:
"""Get all the details for all workers.

Expand All @@ -1651,7 +1656,7 @@ async def workers_all_details(
"""
if self._horde_client_session is not None:
response = await self._horde_client_session.submit_request(
AllWorkersDetailsRequest(),
AllWorkersDetailsRequest(name=worker_name),
AllWorkersDetailsResponse,
)
else:
Expand Down
9 changes: 7 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/workers/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,15 @@ def get_api_model_name(cls) -> str | None:


class AllWorkersDetailsRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
"""Returns information on all works. If a moderator API key is specified, it will return additional information."""
"""Returns information on all workers.

If a moderator API key is specified, it will return additional information.
"""

type_: WORKER_TYPE = Field(WORKER_TYPE.all, alias="type")
"""Filter workers by type. Default is 'all' which returns all workers."""
name: str | None = Field(None)
"""Returns a worker matching the exact name provided. Case insensitive."""

@override
@classmethod
Expand All @@ -217,7 +222,7 @@ def get_default_success_response_type(cls) -> type[AllWorkersDetailsResponse]:
@override
@classmethod
def get_query_fields(cls) -> list[str]:
return ["type_"]
return ["type_", "name"]

@classmethod
def is_api_key_required(cls) -> bool:
Expand Down
Loading