diff --git a/.github/workflows/static-analysis.yaml b/.github/workflows/static-analysis.yaml index 06bdb8b2cc92..e8fe3c2f3ab9 100644 --- a/.github/workflows/static-analysis.yaml +++ b/.github/workflows/static-analysis.yaml @@ -66,63 +66,4 @@ jobs: - name: Run pre-commit run: | - pre-commit run --show-diff-on-failure --color=always --all-files - - type-completeness-check: - name: Type completeness check - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - persist-credentials: false - fetch-depth: 0 - - - name: Set up uv - uses: astral-sh/setup-uv@v5 - with: - python-version: "3.12" - - - name: Calculate type completeness score - id: calculate_current_score - run: | - # `pyright` will exit with a non-zero status code if it finds any issues, - # so we need to explicitly ignore the exit code with `|| true`. - uv tool run --with-editable . pyright --verifytypes prefect --ignoreexternal --outputjson > prefect-analysis.json || true - SCORE=$(jq -r '.typeCompleteness.completenessScore' prefect-analysis.json) - echo "current_score=$SCORE" >> $GITHUB_OUTPUT - - - name: Checkout base branch - run: | - git checkout ${{ github.base_ref }} - - - name: Calculate base branch score - id: calculate_base_score - run: | - uv tool run --with-editable . pyright --verifytypes prefect --ignoreexternal --outputjson > prefect-analysis-base.json || true - BASE_SCORE=$(jq -r '.typeCompleteness.completenessScore' prefect-analysis-base.json) - echo "base_score=$BASE_SCORE" >> $GITHUB_OUTPUT - - - name: Compare scores - run: | - CURRENT_SCORE=$(echo ${{ steps.calculate_current_score.outputs.current_score }}) - BASE_SCORE=$(echo ${{ steps.calculate_base_score.outputs.base_score }}) - - if (( $(echo "$BASE_SCORE > $CURRENT_SCORE" | bc -l) )); then - echo "::notice title=Type Completeness Check::We noticed a decrease in type coverage with these changes. Check workflow summary for more details." - echo "### ℹ️ Type Completeness Check" >> $GITHUB_STEP_SUMMARY - echo "We noticed a decrease in type coverage with these changes. To maintain our codebase quality, we aim to keep or improve type coverage with each change." >> $GITHUB_STEP_SUMMARY - echo "" >> $GITHUB_STEP_SUMMARY - echo "Need help? Ping @desertaxle or @zzstoatzz for assistance!" >> $GITHUB_STEP_SUMMARY - echo "" >> $GITHUB_STEP_SUMMARY - echo "Here's what changed:" >> $GITHUB_STEP_SUMMARY - echo "" >> $GITHUB_STEP_SUMMARY - uv run scripts/pyright_diff.py prefect-analysis-base.json prefect-analysis.json >> $GITHUB_STEP_SUMMARY - SCORE_DIFF=$(echo "$BASE_SCORE - $CURRENT_SCORE" | bc -l) - if (( $(echo "$SCORE_DIFF > 0.001" | bc -l) )); then - exit 1 - fi - elif (( $(echo "$BASE_SCORE < $CURRENT_SCORE" | bc -l) )); then - echo "🎉 Great work! The type coverage has improved with these changes" >> $GITHUB_STEP_SUMMARY - else - echo "✅ Type coverage maintained" >> $GITHUB_STEP_SUMMARY - fi + pre-commit run --show-diff-on-failure --color=always --all-files \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2773300caa3a..eda214b85bc9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,6 +32,11 @@ repos: )$ - repo: local hooks: + - id: type-completeness-check + name: Type Completeness Check + language: system + entry: uv run --with pyright pyright --ignoreexternal --verifytypes prefect + pass_filenames: false - id: generate-mintlify-openapi-docs name: Generating OpenAPI docs for Mintlify language: system diff --git a/README.md b/README.md index 11dfb13c0366..3c8e7cbce6ed 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,28 @@

+

+ + Installation + + · + + Quickstart + + · + + Build workflows + + · + + Deploy workflows + + · + + Prefect Cloud + +

+ # Prefect Prefect is a workflow orchestration framework for building data pipelines in Python. @@ -26,6 +48,11 @@ With just a few lines of code, data teams can confidently automate any data proc Workflow activity is tracked and can be monitored with a self-hosted [Prefect server](https://docs.prefect.io/latest/manage/self-host/?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) instance or managed [Prefect Cloud](https://www.prefect.io/cloud-vs-oss?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) dashboard. +> [!TIP] +> Prefect flows can handle retries, dependencies, and even complex branching logic +> +> [Check our docs](https://docs.prefect.io/v3/get-started/index?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) or see the example below to learn more! + ## Getting started Prefect requires Python 3.9 or later. To [install the latest or upgrade to the latest version of Prefect](https://docs.prefect.io/get-started/install), run the following command: @@ -79,8 +106,12 @@ if __name__ == "__main__": You now have a process running locally that is looking for scheduled deployments! Additionally you can run your workflow manually from the UI or CLI. You can even run deployments in response to [events](https://docs.prefect.io/latest/automate/?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none). -> [!NOTE] -> To explore different infrastructure options for your workflows, check out the [deployment documentation](https://docs.prefect.io/v3/deploy). +> [!TIP] +> Where to go next - check out our [documentation](https://docs.prefect.io/v3/get-started/index?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) to learn more about: +> - [Deploying flows to production environments](https://docs.prefect.io/v3/deploy?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) +> - [Adding error handling and retries](https://docs.prefect.io/v3/develop/write-tasks#retries?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) +> - [Integrating with your existing tools](https://docs.prefect.io/integrations/integrations?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) +> - [Setting up team collaboration features](https://docs.prefect.io/v3/manage/cloud/manage-users/manage-teams#manage-teams?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) ## Prefect Cloud diff --git a/docs/v3/api-ref/rest-api/server/schema.json b/docs/v3/api-ref/rest-api/server/schema.json index 036dc192dbff..3e4acc39393f 100644 --- a/docs/v3/api-ref/rest-api/server/schema.json +++ b/docs/v3/api-ref/rest-api/server/schema.json @@ -9355,7 +9355,10 @@ "description": "Successful Response", "content": { "application/json": { - "schema": {} + "schema": { + "type": "object", + "title": "Response Validate Obj Ui Schemas Validate Post" + } } } }, @@ -9727,7 +9730,10 @@ "description": "Successful Response", "content": { "application/json": { - "schema": {} + "schema": { + "type": "string", + "title": "Response Hello Hello Get" + } } } }, @@ -22594,6 +22600,9 @@ "description": "An ORM representation of task run data." }, "TaskRunCount": { + "additionalProperties": { + "type": "integer" + }, "type": "object" }, "TaskRunCreate": { diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py index 298386ef168d..0b7ab08d6ae9 100644 --- a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py +++ b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py @@ -166,12 +166,12 @@ from google.api_core.client_options import ClientOptions from googleapiclient import discovery from googleapiclient.discovery import Resource +from jsonpatch import JsonPatch from pydantic import Field, field_validator from prefect.logging.loggers import PrefectLogAdapter from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect.utilities.dockerutils import get_prefect_image_name -from prefect.utilities.pydantic import JsonPatch from prefect.workers.base import ( BaseJobConfiguration, BaseVariables, diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py index 0d3b6989378a..348ad84fa35a 100644 --- a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py +++ b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py @@ -11,12 +11,12 @@ # noinspection PyProtectedMember from googleapiclient.discovery import Resource from googleapiclient.errors import HttpError +from jsonpatch import JsonPatch from pydantic import Field, PrivateAttr, field_validator from prefect.logging.loggers import PrefectLogAdapter from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect.utilities.dockerutils import get_prefect_image_name -from prefect.utilities.pydantic import JsonPatch from prefect.workers.base import ( BaseJobConfiguration, BaseVariables, diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py b/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py index d11efff4c3d0..4f92c7a195c1 100644 --- a/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py +++ b/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py @@ -28,11 +28,11 @@ from uuid import uuid4 import anyio +from jsonpatch import JsonPatch from pydantic import Field, field_validator from slugify import slugify from prefect.logging.loggers import PrefectLogAdapter -from prefect.utilities.pydantic import JsonPatch from prefect.workers.base import ( BaseJobConfiguration, BaseVariables, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py index 0d3f58637d66..2f83394816cc 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py @@ -121,6 +121,7 @@ import aiohttp import anyio.abc import kubernetes_asyncio +from jsonpatch import JsonPatch from kubernetes_asyncio import config from kubernetes_asyncio.client import ( ApiClient, @@ -148,7 +149,6 @@ from prefect.server.schemas.core import Flow from prefect.server.schemas.responses import DeploymentResponse from prefect.utilities.dockerutils import get_prefect_image_name -from prefect.utilities.pydantic import JsonPatch from prefect.utilities.templating import find_placeholders from prefect.utilities.timeout import timeout_async from prefect.workers.base import ( diff --git a/src/prefect/_internal/schemas/validators.py b/src/prefect/_internal/schemas/validators.py index 4bbc9c7de15a..879774fb297e 100644 --- a/src/prefect/_internal/schemas/validators.py +++ b/src/prefect/_internal/schemas/validators.py @@ -6,6 +6,8 @@ This will be subject to consolidation and refactoring over the next few months. """ +from __future__ import annotations + import os import re import urllib.parse @@ -627,18 +629,18 @@ def validate_name_present_on_nonanonymous_blocks(values: M) -> M: @overload -def validate_command(v: str) -> Path: +def validate_working_dir(v: str) -> Path: ... @overload -def validate_command(v: None) -> None: +def validate_working_dir(v: None) -> None: ... -def validate_command(v: Optional[str]) -> Optional[Path]: +def validate_working_dir(v: Optional[Path | str]) -> Optional[Path]: """Make sure that the working directory is formatted for the current platform.""" - if v is not None: + if isinstance(v, str): return relative_path_to_current_platform(v) return v diff --git a/src/prefect/artifacts.py b/src/prefect/artifacts.py index b8903f960c8c..559093abb862 100644 --- a/src/prefect/artifacts.py +++ b/src/prefect/artifacts.py @@ -20,7 +20,10 @@ from prefect.utilities.asyncutils import sync_compatible from prefect.utilities.context import get_task_and_flow_run_ids -logger = get_logger("artifacts") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("artifacts") if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient diff --git a/src/prefect/automations.py b/src/prefect/automations.py index 3201508220bb..6ec5192a4354 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -137,7 +137,7 @@ def create(self: Self) -> Self: self.id = client.create_automation(automation=automation) return self - async def aupdate(self: Self): + async def aupdate(self: Self) -> None: """ Updates an existing automation. diff --git a/src/prefect/blocks/abstract.py b/src/prefect/blocks/abstract.py index ca29dd04b03e..495eaa67f446 100644 --- a/src/prefect/blocks/abstract.py +++ b/src/prefect/blocks/abstract.py @@ -15,7 +15,7 @@ Union, ) -from typing_extensions import Self, TypeAlias +from typing_extensions import TYPE_CHECKING, Self, TypeAlias from prefect.blocks.core import Block from prefect.exceptions import MissingContextError @@ -26,7 +26,10 @@ if sys.version_info >= (3, 12): LoggingAdapter = logging.LoggerAdapter[logging.Logger] else: - LoggingAdapter = logging.LoggerAdapter + if TYPE_CHECKING: + LoggingAdapter = logging.LoggerAdapter[logging.Logger] + else: + LoggingAdapter = logging.LoggerAdapter LoggerOrAdapter: TypeAlias = Union[Logger, LoggingAdapter] diff --git a/src/prefect/cli/dashboard.py b/src/prefect/cli/dashboard.py index a7b2caea7145..47154f94cac8 100644 --- a/src/prefect/cli/dashboard.py +++ b/src/prefect/cli/dashboard.py @@ -8,7 +8,7 @@ from prefect.settings import PREFECT_UI_URL from prefect.utilities.asyncutils import run_sync_in_worker_thread -dashboard_app = PrefectTyper( +dashboard_app: PrefectTyper = PrefectTyper( name="dashboard", help="Commands for interacting with the Prefect UI.", ) @@ -16,7 +16,7 @@ @dashboard_app.command() -async def open(): +async def open() -> None: """ Open the Prefect UI in the browser. """ diff --git a/src/prefect/cli/dev.py b/src/prefect/cli/dev.py index 4a05e53bc398..c901d8fe717d 100644 --- a/src/prefect/cli/dev.py +++ b/src/prefect/cli/dev.py @@ -35,13 +35,13 @@ Note that many of these commands require extra dependencies (such as npm and MkDocs) to function properly. """ -dev_app = PrefectTyper( +dev_app: PrefectTyper = PrefectTyper( name="dev", short_help="Internal Prefect development.", help=DEV_HELP ) app.add_typer(dev_app) -def exit_with_error_if_not_editable_install(): +def exit_with_error_if_not_editable_install() -> None: if ( prefect.__module_path__.parent == "site-packages" or not (prefect.__development_base_path__ / "setup.py").exists() diff --git a/src/prefect/cli/events.py b/src/prefect/cli/events.py index 1b84c8923947..e9e7b423c15c 100644 --- a/src/prefect/cli/events.py +++ b/src/prefect/cli/events.py @@ -15,7 +15,7 @@ get_events_subscriber, ) -events_app = PrefectTyper(name="events", help="Stream events.") +events_app: PrefectTyper = PrefectTyper(name="events", help="Stream events.") app.add_typer(events_app, aliases=["event"]) @@ -60,7 +60,7 @@ async def stream( handle_error(exc) -async def handle_event(event: Event, format: StreamFormat, output_file: str): +async def handle_event(event: Event, format: StreamFormat, output_file: str) -> None: if format == StreamFormat.json: event_data = orjson.dumps(event.model_dump(), default=str).decode() elif format == StreamFormat.text: @@ -74,7 +74,7 @@ async def handle_event(event: Event, format: StreamFormat, output_file: str): print(event_data) -def handle_error(exc): +def handle_error(exc: Exception) -> None: if isinstance(exc, websockets.exceptions.ConnectionClosedError): exit_with_error(f"Connection closed, retrying... ({exc})") elif isinstance(exc, (KeyboardInterrupt, asyncio.exceptions.CancelledError)): diff --git a/src/prefect/cli/flow.py b/src/prefect/cli/flow.py index acc37b4a24cd..735e1958b374 100644 --- a/src/prefect/cli/flow.py +++ b/src/prefect/cli/flow.py @@ -19,7 +19,7 @@ from prefect.runner import Runner from prefect.utilities import urls -flow_app = PrefectTyper(name="flow", help="View and serve flows.") +flow_app: PrefectTyper = PrefectTyper(name="flow", help="View and serve flows.") app.add_typer(flow_app, aliases=["flows"]) diff --git a/src/prefect/cli/flow_run.py b/src/prefect/cli/flow_run.py index 41bb4fd4189f..63d862c081da 100644 --- a/src/prefect/cli/flow_run.py +++ b/src/prefect/cli/flow_run.py @@ -28,13 +28,15 @@ from prefect.runner import Runner from prefect.states import State -flow_run_app = PrefectTyper(name="flow-run", help="Interact with flow runs.") +flow_run_app: PrefectTyper = PrefectTyper( + name="flow-run", help="Interact with flow runs." +) app.add_typer(flow_run_app, aliases=["flow-runs"]) LOGS_DEFAULT_PAGE_SIZE = 200 LOGS_WITH_LIMIT_FLAG_DEFAULT_NUM_LOGS = 20 -logger = get_logger(__name__) +logger: "logging.Logger" = get_logger(__name__) @flow_run_app.command() diff --git a/src/prefect/cli/global_concurrency_limit.py b/src/prefect/cli/global_concurrency_limit.py index a391fe327476..dff54dd4d6c3 100644 --- a/src/prefect/cli/global_concurrency_limit.py +++ b/src/prefect/cli/global_concurrency_limit.py @@ -22,7 +22,7 @@ PrefectHTTPStatusError, ) -global_concurrency_limit_app = PrefectTyper( +global_concurrency_limit_app: PrefectTyper = PrefectTyper( name="global-concurrency-limit", help="Manage global concurrency limits.", ) diff --git a/src/prefect/cli/profile.py b/src/prefect/cli/profile.py index ce74709ba5ca..55dbd720f85a 100644 --- a/src/prefect/cli/profile.py +++ b/src/prefect/cli/profile.py @@ -26,7 +26,9 @@ from prefect.settings import ProfilesCollection from prefect.utilities.collections import AutoEnum -profile_app = PrefectTyper(name="profile", help="Select and manage Prefect profiles.") +profile_app: PrefectTyper = PrefectTyper( + name="profile", help="Select and manage Prefect profiles." +) app.add_typer(profile_app, aliases=["profiles"]) _OLD_MINIMAL_DEFAULT_PROFILE_CONTENT: str = """active = "default" @@ -263,8 +265,8 @@ def inspect( def show_profile_changes( user_profiles: ProfilesCollection, default_profiles: ProfilesCollection -): - changes = [] +) -> bool: + changes: list[tuple[str, str]] = [] for name in default_profiles.names: if name not in user_profiles: @@ -343,7 +345,7 @@ class ConnectionStatus(AutoEnum): INVALID_API = AutoEnum.auto() -async def check_server_connection(): +async def check_server_connection() -> ConnectionStatus: httpx_settings = dict(timeout=3) try: # attempt to infer Cloud 2.0 API from the connection URL diff --git a/src/prefect/cli/root.py b/src/prefect/cli/root.py index a6aa37ea7afd..1a3b9973e225 100644 --- a/src/prefect/cli/root.py +++ b/src/prefect/cli/root.py @@ -26,16 +26,16 @@ PREFECT_TEST_MODE, ) -app = PrefectTyper(add_completion=True, no_args_is_help=True) +app: PrefectTyper = PrefectTyper(add_completion=True, no_args_is_help=True) -def version_callback(value: bool): +def version_callback(value: bool) -> None: if value: print(prefect.__version__) raise typer.Exit() -def is_interactive(): +def is_interactive() -> bool: return app.console.is_interactive @@ -157,7 +157,7 @@ def get_prefect_integrations() -> Dict[str, str]: return integrations -def display(object: Dict[str, Any], nesting: int = 0): +def display(object: Dict[str, Any], nesting: int = 0) -> None: """Recursive display of a dictionary with nesting.""" for key, value in object.items(): key += ":" diff --git a/src/prefect/cli/server.py b/src/prefect/cli/server.py index acbf56a52871..8b9996931bf1 100644 --- a/src/prefect/cli/server.py +++ b/src/prefect/cli/server.py @@ -15,6 +15,7 @@ import textwrap from pathlib import Path from types import ModuleType +from typing import TYPE_CHECKING import typer import uvicorn @@ -48,23 +49,30 @@ from prefect.settings.context import temporary_settings from prefect.utilities.asyncutils import run_sync_in_worker_thread -server_app = PrefectTyper( +if TYPE_CHECKING: + import logging + +server_app: PrefectTyper = PrefectTyper( name="server", help="Start a Prefect server instance and interact with the database", ) -database_app = PrefectTyper(name="database", help="Interact with the database.") -services_app = PrefectTyper(name="services", help="Interact with server loop services.") +database_app: PrefectTyper = PrefectTyper( + name="database", help="Interact with the database." +) +services_app: PrefectTyper = PrefectTyper( + name="services", help="Interact with server loop services." +) server_app.add_typer(database_app) server_app.add_typer(services_app) app.add_typer(server_app) -logger = get_logger(__name__) +logger: "logging.Logger" = get_logger(__name__) SERVER_PID_FILE_NAME = "server.pid" SERVICES_PID_FILE = Path(PREFECT_HOME.value()) / "services.pid" -def generate_welcome_blurb(base_url: str, ui_enabled: bool): +def generate_welcome_blurb(base_url: str, ui_enabled: bool) -> str: blurb = textwrap.dedent( r""" ___ ___ ___ ___ ___ ___ _____ @@ -109,7 +117,7 @@ def generate_welcome_blurb(base_url: str, ui_enabled: bool): return blurb -def prestart_check(base_url: str): +def prestart_check(base_url: str) -> None: """ Check if `PREFECT_API_URL` is set in the current profile. If not, prompt the user to set it. diff --git a/src/prefect/cli/shell.py b/src/prefect/cli/shell.py index 6837e5d00950..968ddd2887f7 100644 --- a/src/prefect/cli/shell.py +++ b/src/prefect/cli/shell.py @@ -8,7 +8,7 @@ import subprocess import sys import threading -from typing import Any, Dict, List, Optional +from typing import IO, Any, Callable, Dict, List, Optional import typer from typing_extensions import Annotated @@ -25,13 +25,13 @@ from prefect.settings import PREFECT_UI_URL from prefect.types.entrypoint import EntrypointType -shell_app = PrefectTyper( +shell_app: PrefectTyper = PrefectTyper( name="shell", help="Serve and watch shell commands as Prefect flows." ) app.add_typer(shell_app) -def output_stream(pipe, logger_function): +def output_stream(pipe: IO[str], logger_function: Callable[[str], None]) -> None: """ Read from a pipe line by line and log using the provided logging function. @@ -44,7 +44,7 @@ def output_stream(pipe, logger_function): logger_function(line.strip()) -def output_collect(pipe, container): +def output_collect(pipe: IO[str], container: list[str]) -> None: """ Collects output from a subprocess pipe and stores it in a container list. diff --git a/src/prefect/cli/task.py b/src/prefect/cli/task.py index bedc2b4f418e..21f059190778 100644 --- a/src/prefect/cli/task.py +++ b/src/prefect/cli/task.py @@ -8,7 +8,7 @@ from prefect.task_worker import serve as task_serve from prefect.utilities.importtools import import_object -task_app = PrefectTyper(name="task", help="Work with task scheduling.") +task_app: PrefectTyper = PrefectTyper(name="task", help="Work with task scheduling.") app.add_typer(task_app, aliases=["task"]) diff --git a/src/prefect/cli/task_run.py b/src/prefect/cli/task_run.py index e0c35177fdec..b425b9a3c307 100644 --- a/src/prefect/cli/task_run.py +++ b/src/prefect/cli/task_run.py @@ -22,7 +22,9 @@ from prefect.client.schemas.sorting import LogSort, TaskRunSort from prefect.exceptions import ObjectNotFound -task_run_app = PrefectTyper(name="task-run", help="View and inspect task runs.") +task_run_app: PrefectTyper = PrefectTyper( + name="task-run", help="View and inspect task runs." +) app.add_typer(task_run_app, aliases=["task-runs"]) LOGS_DEFAULT_PAGE_SIZE = 200 diff --git a/src/prefect/cli/variable.py b/src/prefect/cli/variable.py index dfb1a7cac7d0..919ba3556eaf 100644 --- a/src/prefect/cli/variable.py +++ b/src/prefect/cli/variable.py @@ -14,7 +14,7 @@ from prefect.client.schemas.actions import VariableCreate, VariableUpdate from prefect.exceptions import ObjectNotFound -variable_app = PrefectTyper(name="variable", help="Manage variables.") +variable_app: PrefectTyper = PrefectTyper(name="variable", help="Manage variables.") app.add_typer(variable_app) diff --git a/src/prefect/cli/work_pool.py b/src/prefect/cli/work_pool.py index 2c9af2a40f52..ff528d959458 100644 --- a/src/prefect/cli/work_pool.py +++ b/src/prefect/cli/work_pool.py @@ -32,11 +32,11 @@ get_default_base_job_template_for_infrastructure_type, ) -work_pool_app = PrefectTyper(name="work-pool", help="Manage work pools.") +work_pool_app: PrefectTyper = PrefectTyper(name="work-pool", help="Manage work pools.") app.add_typer(work_pool_app, aliases=["work-pool"]) -def set_work_pool_as_default(name: str): +def set_work_pool_as_default(name: str) -> None: profile = update_current_profile({"PREFECT_DEFAULT_WORK_POOL_NAME": name}) app.console.print( f"Set {name!r} as default work pool for profile {profile.name!r}\n", diff --git a/src/prefect/cli/work_queue.py b/src/prefect/cli/work_queue.py index 6cabb8700f3a..069d8296b48a 100644 --- a/src/prefect/cli/work_queue.py +++ b/src/prefect/cli/work_queue.py @@ -20,7 +20,7 @@ from prefect.client.schemas.objects import DEFAULT_AGENT_WORK_POOL_NAME from prefect.exceptions import ObjectAlreadyExists, ObjectNotFound -work_app = PrefectTyper(name="work-queue", help="Manage work queues.") +work_app: PrefectTyper = PrefectTyper(name="work-queue", help="Manage work queues.") app.add_typer(work_app, aliases=["work-queues"]) diff --git a/src/prefect/cli/worker.py b/src/prefect/cli/worker.py index 523787548eed..34718485ceff 100644 --- a/src/prefect/cli/worker.py +++ b/src/prefect/cli/worker.py @@ -28,7 +28,9 @@ ) from prefect.workers.base import BaseWorker -worker_app = PrefectTyper(name="worker", help="Start and interact with workers.") +worker_app: PrefectTyper = PrefectTyper( + name="worker", help="Start and interact with workers." +) app.add_typer(worker_app) diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py index 4622a7d6fe32..3aa7043333d3 100644 --- a/src/prefect/client/utilities.py +++ b/src/prefect/client/utilities.py @@ -5,7 +5,7 @@ # This module must not import from `prefect.client` when it is imported to avoid # circular imports for decorators such as `inject_client` which are widely used. -from collections.abc import Awaitable, Coroutine +from collections.abc import Coroutine from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -61,8 +61,8 @@ def get_or_create_client( def client_injector( - func: Callable[Concatenate["PrefectClient", P], Awaitable[R]], -) -> Callable[P, Awaitable[R]]: + func: Callable[Concatenate["PrefectClient", P], Coroutine[Any, Any, R]], +) -> Callable[P, Coroutine[Any, Any, R]]: @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: client, _ = get_or_create_client() diff --git a/src/prefect/deployments/base.py b/src/prefect/deployments/base.py index 9ec16a909fff..5700b8c34848 100644 --- a/src/prefect/deployments/base.py +++ b/src/prefect/deployments/base.py @@ -19,6 +19,7 @@ from prefect.client.schemas.objects import ConcurrencyLimitStrategy from prefect.client.schemas.schedules import IntervalSchedule from prefect.utilities._git import get_git_branch, get_git_remote_origin_url +from prefect.utilities.annotations import NotSet from prefect.utilities.filesystem import create_default_ignore_file from prefect.utilities.templating import apply_values @@ -113,7 +114,9 @@ def create_default_prefect_yaml( return True -def configure_project_by_recipe(recipe: str, **formatting_kwargs) -> dict: +def configure_project_by_recipe( + recipe: str, **formatting_kwargs: Any +) -> dict[str, Any] | type[NotSet]: """ Given a recipe name, returns a dictionary representing base configuration options. @@ -131,13 +134,13 @@ def configure_project_by_recipe(recipe: str, **formatting_kwargs) -> dict: raise ValueError(f"Unknown recipe {recipe!r} provided.") with recipe_path.open(mode="r") as f: - config = yaml.safe_load(f) + config: dict[str, Any] = yaml.safe_load(f) - config = apply_values( + templated_config = apply_values( template=config, values=formatting_kwargs, remove_notset=False ) - return config + return templated_config def initialize_project( diff --git a/src/prefect/deployments/flow_runs.py b/src/prefect/deployments/flow_runs.py index e54615a96952..169539df574f 100644 --- a/src/prefect/deployments/flow_runs.py +++ b/src/prefect/deployments/flow_runs.py @@ -29,7 +29,11 @@ } ) -logger = get_logger(__name__) + +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) @sync_compatible diff --git a/src/prefect/deployments/steps/core.py b/src/prefect/deployments/steps/core.py index ef6118b297a9..4efc15dbcaf6 100644 --- a/src/prefect/deployments/steps/core.py +++ b/src/prefect/deployments/steps/core.py @@ -141,7 +141,7 @@ async def run_steps( steps: List[Dict[str, Any]], upstream_outputs: Optional[Dict[str, Any]] = None, print_function: Any = print, -): +) -> dict[str, Any]: upstream_outputs = deepcopy(upstream_outputs) if upstream_outputs else {} for step in steps: if not step: diff --git a/src/prefect/deployments/steps/pull.py b/src/prefect/deployments/steps/pull.py index 1bc0ec06d64f..54faa6f337ee 100644 --- a/src/prefect/deployments/steps/pull.py +++ b/src/prefect/deployments/steps/pull.py @@ -12,7 +12,10 @@ from prefect.runner.storage import BlockStorageAdapter, GitRepository, RemoteStorage from prefect.utilities.asyncutils import run_coro_as_sync -deployment_logger = get_logger("deployment") +if TYPE_CHECKING: + import logging + +deployment_logger: "logging.Logger" = get_logger("deployment") if TYPE_CHECKING: from prefect.blocks.core import Block @@ -197,7 +200,7 @@ def git_clone( return dict(directory=str(storage.destination.relative_to(Path.cwd()))) -async def pull_from_remote_storage(url: str, **settings: Any): +async def pull_from_remote_storage(url: str, **settings: Any) -> dict[str, Any]: """ Pulls code from a remote storage location into the current working directory. @@ -239,7 +242,9 @@ async def pull_from_remote_storage(url: str, **settings: Any): return {"directory": directory} -async def pull_with_block(block_document_name: str, block_type_slug: str): +async def pull_with_block( + block_document_name: str, block_type_slug: str +) -> dict[str, Any]: """ Pulls code using a block. diff --git a/src/prefect/deployments/steps/utility.py b/src/prefect/deployments/steps/utility.py index 21e2436858d0..53f8529ff387 100644 --- a/src/prefect/deployments/steps/utility.py +++ b/src/prefect/deployments/steps/utility.py @@ -26,7 +26,7 @@ import string import subprocess import sys -from typing import Dict, Optional +from typing import Any, Dict, Optional from anyio import create_task_group from anyio.streams.text import TextReceiveStream @@ -205,7 +205,7 @@ async def pip_install_requirements( directory: Optional[str] = None, requirements_file: str = "requirements.txt", stream_output: bool = True, -): +) -> dict[str, Any]: """ Installs dependencies from a requirements.txt file. diff --git a/src/prefect/docker/docker_image.py b/src/prefect/docker/docker_image.py index c58442cd0e94..a2de9980d52a 100644 --- a/src/prefect/docker/docker_image.py +++ b/src/prefect/docker/docker_image.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Any, Optional from pendulum import now as pendulum_now @@ -34,7 +34,11 @@ class DockerImage: """ def __init__( - self, name: str, tag: Optional[str] = None, dockerfile="auto", **build_kwargs + self, + name: str, + tag: Optional[str] = None, + dockerfile: str = "auto", + **build_kwargs: Any, ): image_name, image_tag = parse_image_tag(name) if tag and image_tag: @@ -49,16 +53,16 @@ def __init__( namespace = PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE.value() # join the namespace and repository to create the full image name # ignore namespace if it is None - self.name = "/".join(filter(None, [namespace, repository])) - self.tag = tag or image_tag or slugify(pendulum_now("utc").isoformat()) - self.dockerfile = dockerfile - self.build_kwargs = build_kwargs + self.name: str = "/".join(filter(None, [namespace, repository])) + self.tag: str = tag or image_tag or slugify(pendulum_now("utc").isoformat()) + self.dockerfile: str = dockerfile + self.build_kwargs: dict[str, Any] = build_kwargs @property - def reference(self): + def reference(self) -> str: return f"{self.name}:{self.tag}" - def build(self): + def build(self) -> None: full_image_name = self.reference build_kwargs = self.build_kwargs.copy() build_kwargs["context"] = Path.cwd() @@ -72,7 +76,7 @@ def build(self): build_kwargs["dockerfile"] = self.dockerfile build_image(**build_kwargs) - def push(self): + def push(self) -> None: with docker_client() as client: events = client.api.push( repository=self.name, tag=self.tag, stream=True, decode=True diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 262fdaa67071..a2829a7a2001 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -1,6 +1,6 @@ import os import sys -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from uuid import UUID from prefect._internal.compatibility.migration import getattr_migration @@ -15,12 +15,19 @@ run_coro_as_sync, ) -engine_logger = get_logger("engine") +if TYPE_CHECKING: + import logging + + from prefect.flow_engine import FlowRun + from prefect.flows import Flow + from prefect.logging.loggers import LoggingAdapter + +engine_logger: "logging.Logger" = get_logger("engine") if __name__ == "__main__": try: - flow_run_id = UUID( + flow_run_id: UUID = UUID( sys.argv[1] if len(sys.argv) > 1 else os.environ.get("PREFECT__FLOW_RUN_ID") ) except Exception: @@ -37,11 +44,11 @@ run_flow, ) - flow_run = load_flow_run(flow_run_id=flow_run_id) - run_logger = flow_run_logger(flow_run=flow_run) + flow_run: "FlowRun" = load_flow_run(flow_run_id=flow_run_id) + run_logger: "LoggingAdapter" = flow_run_logger(flow_run=flow_run) try: - flow = load_flow(flow_run) + flow: "Flow[..., Any]" = load_flow(flow_run) except Exception: run_logger.error( "Unexpected exception encountered when trying to load flow", @@ -55,15 +62,17 @@ else: run_flow(flow, flow_run=flow_run, error_logger=run_logger) - except Abort as exc: + except Abort as abort_signal: + abort_signal: Abort engine_logger.info( f"Engine execution of flow run '{flow_run_id}' aborted by orchestrator:" - f" {exc}" + f" {abort_signal}" ) exit(0) - except Pause as exc: + except Pause as pause_signal: + pause_signal: Pause engine_logger.info( - f"Engine execution of flow run '{flow_run_id}' is paused: {exc}" + f"Engine execution of flow run '{flow_run_id}' is paused: {pause_signal}" ) exit(0) except Exception: diff --git a/src/prefect/events/cli/automations.py b/src/prefect/events/cli/automations.py index b4ce5dbac804..c70c2ec34c05 100644 --- a/src/prefect/events/cli/automations.py +++ b/src/prefect/events/cli/automations.py @@ -3,7 +3,7 @@ """ import functools -from typing import Optional, Type +from typing import Any, Callable, Optional, Type from uuid import UUID import orjson @@ -21,16 +21,16 @@ from prefect.events.schemas.automations import Automation from prefect.exceptions import PrefectHTTPStatusError -automations_app = PrefectTyper( +automations_app: PrefectTyper = PrefectTyper( name="automation", help="Manage automations.", ) app.add_typer(automations_app, aliases=["automations"]) -def requires_automations(func): +def requires_automations(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Any: try: return await func(*args, **kwargs) except RuntimeError as exc: diff --git a/src/prefect/events/clients.py b/src/prefect/events/clients.py index 41b83ee9f41e..3ce243ee9337 100644 --- a/src/prefect/events/clients.py +++ b/src/prefect/events/clients.py @@ -70,18 +70,21 @@ labelnames=["client"], ) -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) -def http_to_ws(url: str): +def http_to_ws(url: str) -> str: return url.replace("https://", "wss://").replace("http://", "ws://").rstrip("/") -def events_in_socket_from_api_url(url: str): +def events_in_socket_from_api_url(url: str) -> str: return http_to_ws(url) + "/events/in" -def events_out_socket_from_api_url(url: str): +def events_out_socket_from_api_url(url: str) -> str: return http_to_ws(url) + "/events/out" @@ -250,11 +253,11 @@ class AssertingEventsClient(EventsClient): last: ClassVar["Optional[AssertingEventsClient]"] = None all: ClassVar[List["AssertingEventsClient"]] = [] - args: Tuple - kwargs: Dict[str, Any] - events: List[Event] + args: tuple[Any, ...] + kwargs: dict[str, Any] + events: list[Event] - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): AssertingEventsClient.last = self AssertingEventsClient.all.append(self) self.args = args @@ -431,13 +434,13 @@ class AssertingPassthroughEventsClient(PrefectEventsClient): during tests AND sends them to a Prefect server.""" last: ClassVar["Optional[AssertingPassthroughEventsClient]"] = None - all: ClassVar[List["AssertingPassthroughEventsClient"]] = [] + all: ClassVar[list["AssertingPassthroughEventsClient"]] = [] - args: Tuple - kwargs: Dict[str, Any] - events: List[Event] + args: tuple[Any, ...] + kwargs: dict[str, Any] + events: list[Event] - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) AssertingPassthroughEventsClient.last = self AssertingPassthroughEventsClient.all.append(self) @@ -449,7 +452,7 @@ def reset(cls) -> None: cls.last = None cls.all = [] - def pop_events(self) -> List[Event]: + def pop_events(self) -> list[Event]: events = self.events self.events = [] return events diff --git a/src/prefect/events/schemas/automations.py b/src/prefect/events/schemas/automations.py index 97fe5c7e9a50..2eeea40214c6 100644 --- a/src/prefect/events/schemas/automations.py +++ b/src/prefect/events/schemas/automations.py @@ -52,7 +52,7 @@ def describe_for_cli(self, indent: int = 0) -> str: _deployment_id: Optional[UUID] = PrivateAttr(default=None) - def set_deployment_id(self, deployment_id: UUID): + def set_deployment_id(self, deployment_id: UUID) -> None: self._deployment_id = deployment_id def owner_resource(self) -> Optional[str]: @@ -277,7 +277,7 @@ class MetricTriggerQuery(PrefectBaseModel): ) @field_validator("range", "firing_for") - def enforce_minimum_range(cls, value: timedelta): + def enforce_minimum_range(cls, value: timedelta) -> timedelta: if value < timedelta(seconds=300): raise ValueError("The minimum range is 300 seconds (5 minutes)") return value @@ -404,13 +404,17 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg """Defines an action a user wants to take when a certain number of events do or don't happen to the matching resources""" - name: str = Field(..., description="The name of this automation") - description: str = Field("", description="A longer description of this automation") + name: str = Field(default=..., description="The name of this automation") + description: str = Field( + default="", description="A longer description of this automation" + ) - enabled: bool = Field(True, description="Whether this automation will be evaluated") + enabled: bool = Field( + default=True, description="Whether this automation will be evaluated" + ) trigger: TriggerTypes = Field( - ..., + default=..., description=( "The criteria for which events this Automation covers and how it will " "respond to the presence or absence of those events" @@ -418,7 +422,7 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg ) actions: List[ActionTypes] = Field( - ..., + default=..., description="The actions to perform when this Automation triggers", ) @@ -438,4 +442,4 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg class Automation(AutomationCore): - id: UUID = Field(..., description="The ID of this automation") + id: UUID = Field(default=..., description="The ID of this automation") diff --git a/src/prefect/events/schemas/events.py b/src/prefect/events/schemas/events.py index eddb69e42742..f7d7e7dba57b 100644 --- a/src/prefect/events/schemas/events.py +++ b/src/prefect/events/schemas/events.py @@ -1,6 +1,7 @@ import copy from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, ClassVar, Dict, @@ -32,7 +33,10 @@ from .labelling import Labelled -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) class Resource(Labelled): diff --git a/src/prefect/filesystems.py b/src/prefect/filesystems.py index 20c0a45d23dd..aec69a916bd7 100644 --- a/src/prefect/filesystems.py +++ b/src/prefect/filesystems.py @@ -281,7 +281,7 @@ class RemoteFileSystem(WritableFileSystem, WritableDeploymentStorage): _filesystem: fsspec.AbstractFileSystem = None @field_validator("basepath") - def check_basepath(cls, value): + def check_basepath(cls, value: str) -> str: return validate_basepath(value) def _resolve_path(self, path: str) -> str: diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index cbf50b9d17d3..359e7ff5f995 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -146,7 +146,7 @@ class BaseFlowRunEngine(Generic[P, R]): _flow_run_name_set: bool = False _telemetry: RunTelemetry = field(default_factory=RunTelemetry) - def __post_init__(self): + def __post_init__(self) -> None: if self.flow is None and self.flow_run_id is None: raise ValueError("Either a flow or a flow_run_id must be provided.") @@ -167,7 +167,7 @@ def is_pending(self) -> bool: return False # TODO: handle this differently? return getattr(self, "flow_run").state.is_pending() - def cancel_all_tasks(self): + def cancel_all_tasks(self) -> None: if hasattr(self.flow.task_runner, "cancel_all"): self.flow.task_runner.cancel_all() # type: ignore @@ -208,6 +208,8 @@ def _update_otel_labels( @dataclass class FlowRunEngine(BaseFlowRunEngine[P, R]): _client: Optional[SyncPrefectClient] = None + flow_run: FlowRun | None = None + parameters: dict[str, Any] | None = None @property def client(self) -> SyncPrefectClient: @@ -502,7 +504,7 @@ def create_flow_run(self, client: SyncPrefectClient) -> FlowRun: tags=TagsContext.get().current_tags, ) - def call_hooks(self, state: Optional[State] = None): + def call_hooks(self, state: Optional[State] = None) -> None: if state is None: state = self.state flow = self.flow @@ -600,7 +602,9 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): # set the logger to the flow run logger - self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow) + self.logger: "logging.Logger" = flow_run_logger( + flow_run=self.flow_run, flow=self.flow + ) # type: ignore # update the flow run name if necessary if not self._flow_run_name_set and self.flow.flow_run_name: @@ -768,6 +772,8 @@ class AsyncFlowRunEngine(BaseFlowRunEngine[P, R]): """ _client: Optional[PrefectClient] = None + parameters: dict[str, Any] | None = None + flow_run: FlowRun | None = None @property def client(self) -> PrefectClient: @@ -1061,7 +1067,7 @@ async def create_flow_run(self, client: PrefectClient) -> FlowRun: tags=TagsContext.get().current_tags, ) - async def call_hooks(self, state: Optional[State] = None): + async def call_hooks(self, state: Optional[State] = None) -> None: if state is None: state = self.state flow = self.flow @@ -1158,7 +1164,9 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None): stack.enter_context(ConcurrencyContext()) # set the logger to the flow run logger - self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow) + self.logger: "logging.Logger" = flow_run_logger( + flow_run=self.flow_run, flow=self.flow + ) # update the flow run name if necessary @@ -1320,7 +1328,7 @@ def run_flow_sync( flow: Flow[P, R], flow_run: Optional[FlowRun] = None, parameters: Optional[Dict[str, Any]] = None, - wait_for: Optional[Iterable[PrefectFuture]] = None, + wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, return_type: Literal["state", "result"] = "result", ) -> Union[R, State, None]: engine = FlowRunEngine[P, R]( @@ -1342,7 +1350,7 @@ async def run_flow_async( flow: Flow[P, R], flow_run: Optional[FlowRun] = None, parameters: Optional[Dict[str, Any]] = None, - wait_for: Optional[Iterable[PrefectFuture]] = None, + wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, return_type: Literal["state", "result"] = "result", ) -> Union[R, State, None]: engine = AsyncFlowRunEngine[P, R]( @@ -1361,7 +1369,7 @@ def run_generator_flow_sync( flow: Flow[P, R], flow_run: Optional[FlowRun] = None, parameters: Optional[Dict[str, Any]] = None, - wait_for: Optional[Iterable[PrefectFuture]] = None, + wait_for: Optional[Iterable[PrefectFuture[Any]]] = None, return_type: Literal["state", "result"] = "result", ) -> Generator[R, None, None]: if return_type != "result": diff --git a/src/prefect/flows.py b/src/prefect/flows.py index af93f1a652d1..9a3af1eff546 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -86,6 +86,7 @@ sync_compatible, ) from prefect.utilities.callables import ( + ParameterSchema, get_call_parameters, parameter_schema, parameters_to_args_kwargs, @@ -272,7 +273,7 @@ def __init__( if not callable(fn): raise TypeError("'fn' must be callable") - self.name = name or fn.__name__.replace("_", "-").replace( + self.name: str = name or fn.__name__.replace("_", "-").replace( "", "unknown-lambda", # prefect API will not accept "<" or ">" in flow names ) @@ -287,29 +288,29 @@ def __init__( self.flow_run_name = flow_run_name if task_runner is None: - self.task_runner = cast( + self.task_runner: TaskRunner[PrefectFuture[Any]] = cast( TaskRunner[PrefectFuture[Any]], ThreadPoolTaskRunner() ) else: - self.task_runner = ( + self.task_runner: TaskRunner[PrefectFuture[Any]] = ( task_runner() if isinstance(task_runner, type) else task_runner ) self.log_prints = log_prints - self.description = description or inspect.getdoc(fn) + self.description: str | None = description or inspect.getdoc(fn) update_wrapper(self, fn) self.fn = fn # the flow is considered async if its function is async or an async # generator - self.isasync = asyncio.iscoroutinefunction( + self.isasync: bool = asyncio.iscoroutinefunction( self.fn ) or inspect.isasyncgenfunction(self.fn) # the flow is considered a generator if its function is a generator or # an async generator - self.isgenerator = inspect.isgeneratorfunction( + self.isgenerator: bool = inspect.isgeneratorfunction( self.fn ) or inspect.isasyncgenfunction(self.fn) @@ -326,22 +327,24 @@ def __init__( pass # `getsourcefile` can return null values and "" for objects in repls self.version = version - self.timeout_seconds = float(timeout_seconds) if timeout_seconds else None + self.timeout_seconds: float | None = ( + float(timeout_seconds) if timeout_seconds else None + ) # FlowRunPolicy settings # TODO: We can instantiate a `FlowRunPolicy` and add Pydantic bound checks to # validate that the user passes positive numbers here - self.retries = ( + self.retries: int = ( retries if retries is not None else PREFECT_FLOW_DEFAULT_RETRIES.value() ) - self.retry_delay_seconds = ( + self.retry_delay_seconds: float | int = ( retry_delay_seconds if retry_delay_seconds is not None else PREFECT_FLOW_DEFAULT_RETRY_DELAY_SECONDS.value() ) - self.parameters = parameter_schema(self.fn) + self.parameters: ParameterSchema = parameter_schema(self.fn) self.should_validate_parameters = validate_parameters if self.should_validate_parameters: @@ -421,7 +424,7 @@ def with_options( description: Optional[str] = None, flow_run_name: Optional[Union[Callable[[], str], str]] = None, task_runner: Union[ - Type[TaskRunner[PrefectFuture[R]]], TaskRunner[PrefectFuture[R]], None + Type[TaskRunner[PrefectFuture[Any]]], TaskRunner[PrefectFuture[Any]], None ] = None, timeout_seconds: Union[int, float, None] = None, validate_parameters: Optional[bool] = None, @@ -1708,7 +1711,7 @@ def from_source( ... -flow = FlowDecorator() +flow: FlowDecorator = FlowDecorator() def _raise_on_name_with_banned_characters(name: Optional[str]) -> Optional[str]: diff --git a/src/prefect/futures.py b/src/prefect/futures.py index 97e5bba64cc0..ee31acc58503 100644 --- a/src/prefect/futures.py +++ b/src/prefect/futures.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Generator, Iterator from functools import partial -from typing import Any, Callable, Generic, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union from typing_extensions import NamedTuple, Self, TypeVar @@ -22,7 +22,10 @@ F = TypeVar("F") R = TypeVar("R") -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) class PrefectFuture(abc.ABC, Generic[R]): diff --git a/src/prefect/infrastructure/provisioners/__init__.py b/src/prefect/infrastructure/provisioners/__init__.py index 545a8576a102..bb8270360841 100644 --- a/src/prefect/infrastructure/provisioners/__init__.py +++ b/src/prefect/infrastructure/provisioners/__init__.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Type +from prefect.infrastructure.provisioners.coiled import CoiledPushProvisioner from prefect.infrastructure.provisioners.modal import ModalPushProvisioner from .cloud_run import CloudRunPushProvisioner from .container_instance import ContainerInstancePushProvisioner @@ -15,6 +16,7 @@ "azure-container-instance:push": ContainerInstancePushProvisioner, "ecs:push": ElasticContainerServicePushProvisioner, "modal:push": ModalPushProvisioner, + "coiled:push": CoiledPushProvisioner, } diff --git a/src/prefect/infrastructure/provisioners/cloud_run.py b/src/prefect/infrastructure/provisioners/cloud_run.py index 4b6d281c688d..011d5e2f4066 100644 --- a/src/prefect/infrastructure/provisioners/cloud_run.py +++ b/src/prefect/infrastructure/provisioners/cloud_run.py @@ -33,7 +33,7 @@ class CloudRunPushProvisioner: def __init__(self): - self._console = Console() + self._console: Console = Console() self._project = None self._region = None self._service_account_name = "prefect-cloud-run" @@ -41,14 +41,14 @@ def __init__(self): self._image_repository_name = "prefect-images" @property - def console(self): + def console(self) -> Console: return self._console @console.setter - def console(self, value): + def console(self, value: Console) -> None: self._console = value - async def _run_command(self, command: str, *args, **kwargs): + async def _run_command(self, command: str, *args: Any, **kwargs: Any) -> str: result = await run_process(shlex.split(command), check=False, *args, **kwargs) if result.returncode != 0: diff --git a/src/prefect/infrastructure/provisioners/coiled.py b/src/prefect/infrastructure/provisioners/coiled.py new file mode 100644 index 000000000000..f0adb5d7ce50 --- /dev/null +++ b/src/prefect/infrastructure/provisioners/coiled.py @@ -0,0 +1,249 @@ +import importlib +import shlex +import sys +from copy import deepcopy +from types import ModuleType +from typing import TYPE_CHECKING, Any, Dict, Optional + +from anyio import run_process +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.prompt import Confirm + +from prefect.client.schemas.actions import BlockDocumentCreate +from prefect.client.schemas.objects import BlockDocument +from prefect.client.utilities import inject_client +from prefect.exceptions import ObjectNotFound +from prefect.utilities.importtools import lazy_import + +if TYPE_CHECKING: + from prefect.client.orchestration import PrefectClient + + +coiled: ModuleType = lazy_import("coiled") + + +class CoiledPushProvisioner: + """ + A infrastructure provisioner for Coiled push work pools. + """ + + def __init__(self, client: Optional["PrefectClient"] = None): + self._console = Console() + + @property + def console(self) -> Console: + return self._console + + @console.setter + def console(self, value: Console) -> None: + self._console = value + + @staticmethod + def _is_coiled_installed() -> bool: + """ + Checks if the coiled package is installed. + + Returns: + True if the coiled package is installed, False otherwise + """ + try: + importlib.import_module("coiled") + return True + except ModuleNotFoundError: + return False + + async def _install_coiled(self): + """ + Installs the coiled package. + """ + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]Installing coiled..."), + transient=True, + console=self.console, + ) as progress: + task = progress.add_task("coiled install") + progress.start() + global coiled + await run_process( + [shlex.quote(sys.executable), "-m", "pip", "install", "coiled"] + ) + coiled = importlib.import_module("coiled") + progress.advance(task) + + async def _get_coiled_token(self) -> str: + """ + Gets a Coiled API token from the current Coiled configuration. + """ + import dask.config + + return dask.config.get("coiled.token", "") + + async def _create_new_coiled_token(self): + """ + Triggers a Coiled login via the browser if no current token. Will create a new token. + """ + await run_process(["coiled", "login"]) + + async def _create_coiled_credentials_block( + self, + block_document_name: str, + coiled_token: str, + client: "PrefectClient", + ) -> BlockDocument: + """ + Creates a CoiledCredentials block containing the provided token. + + Args: + block_document_name: The name of the block document to create + coiled_token: The Coiled API token + + Returns: + The ID of the created block + """ + assert client is not None, "client injection failed" + try: + credentials_block_type = await client.read_block_type_by_slug( + "coiled-credentials" + ) + except ObjectNotFound: + # Shouldn't happen, but just in case + raise RuntimeError( + "Unable to find CoiledCredentials block type. Please ensure you are" + " using Prefect Cloud." + ) + credentials_block_schema = ( + await client.get_most_recent_block_schema_for_block_type( + block_type_id=credentials_block_type.id + ) + ) + assert ( + credentials_block_schema is not None + ), f"Unable to find schema for block type {credentials_block_type.slug}" + + block_doc = await client.create_block_document( + block_document=BlockDocumentCreate( + name=block_document_name, + data={ + "api_token": coiled_token, + }, + block_type_id=credentials_block_type.id, + block_schema_id=credentials_block_schema.id, + ) + ) + return block_doc + + @inject_client + async def provision( + self, + work_pool_name: str, + base_job_template: Dict[str, Any], + client: Optional["PrefectClient"] = None, + ) -> Dict[str, Any]: + """ + Provisions resources necessary for a Coiled push work pool. + + Provisioned resources: + - A CoiledCredentials block containing a Coiled API token + + Args: + work_pool_name: The name of the work pool to provision resources for + base_job_template: The base job template to update + + Returns: + A copy of the provided base job template with the provisioned resources + """ + credentials_block_name = f"{work_pool_name}-coiled-credentials" + base_job_template_copy = deepcopy(base_job_template) + assert client is not None, "client injection failed" + try: + block_doc = await client.read_block_document_by_name( + credentials_block_name, "coiled-credentials" + ) + self.console.print( + f"Work pool [blue]{work_pool_name!r}[/] will reuse the existing Coiled" + f" credentials block [blue]{credentials_block_name!r}[/blue]" + ) + except ObjectNotFound: + if self._console.is_interactive and not Confirm.ask( + ( + "\n" + "To configure your Coiled push work pool we'll need to store a Coiled" + " API token with Prefect Cloud as a block. We'll pull the token from" + " your local Coiled configuration or create a new token if we" + " can't find one.\n" + "\n" + "Would you like to continue?" + ), + console=self.console, + default=True, + ): + self.console.print( + "No problem! You can always configure your Coiled push work pool" + " later via the Prefect UI." + ) + return base_job_template + + if not self._is_coiled_installed(): + if self.console.is_interactive and Confirm.ask( + ( + "The [blue]coiled[/] package is required to configure" + " authentication for your work pool.\n" + "\n" + "Would you like to install it now?" + ), + console=self.console, + default=True, + ): + await self._install_coiled() + + if not self._is_coiled_installed(): + raise RuntimeError( + "The coiled package is not installed.\n\nPlease try installing coiled," + " or you can use the Prefect UI to create your Coiled push work pool." + ) + + # Get the current Coiled API token + coiled_api_token = await self._get_coiled_token() + if not coiled_api_token: + # Create a new token one wasn't found + if self.console.is_interactive and Confirm.ask( + "Coiled credentials not found. Would you like to create a new token?", + console=self.console, + default=True, + ): + await self._create_new_coiled_token() + coiled_api_token = await self._get_coiled_token() + else: + raise RuntimeError( + "Coiled credentials not found. Please create a new token by" + " running [blue]coiled login[/] and try again." + ) + + # Create the credentials block + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]Saving Coiled credentials..."), + transient=True, + console=self.console, + ) as progress: + task = progress.add_task("create coiled credentials block") + progress.start() + block_doc = await self._create_coiled_credentials_block( + credentials_block_name, + coiled_api_token, + client=client, + ) + progress.advance(task) + + base_job_template_copy["variables"]["properties"]["credentials"]["default"] = { + "$ref": {"block_document_id": str(block_doc.id)} + } + if "image" in base_job_template_copy["variables"]["properties"]: + base_job_template_copy["variables"]["properties"]["image"]["default"] = "" + self.console.print( + f"Successfully configured Coiled push work pool {work_pool_name!r}!", + style="green", + ) + return base_job_template_copy diff --git a/src/prefect/infrastructure/provisioners/container_instance.py b/src/prefect/infrastructure/provisioners/container_instance.py index 17eb85fdb222..4bbace63fa27 100644 --- a/src/prefect/infrastructure/provisioners/container_instance.py +++ b/src/prefect/infrastructure/provisioners/container_instance.py @@ -10,6 +10,7 @@ ContainerInstancePushProvisioner: A class for provisioning infrastructure using Azure Container Instances. """ +from __future__ import annotations import json import random @@ -64,7 +65,7 @@ async def run_command( failure_message: Optional[str] = None, ignore_if_exists: bool = False, return_json: bool = False, - ): + ) -> str | dict[str, Any] | None: """ Runs an Azure CLI command and processes the output. @@ -156,7 +157,7 @@ def __init__(self): self._subscription_name = None self._location = "eastus" self._identity_name = "prefect-acr-identity" - self.azure_cli = AzureCLI(self.console) + self.azure_cli: AzureCLI = AzureCLI(self.console) self._credentials_block_name = None self._resource_group_name = "prefect-aci-push-pool-rg" self._app_registration_name = "prefect-aci-push-pool-app" @@ -170,7 +171,7 @@ def console(self) -> Console: def console(self, value: Console) -> None: self._console = value - async def set_location(self): + async def set_location(self) -> None: """ Set the Azure resource deployment location to the default or 'eastus' on failure. diff --git a/src/prefect/infrastructure/provisioners/ecs.py b/src/prefect/infrastructure/provisioners/ecs.py index a851055e11d3..2f6cb1460f87 100644 --- a/src/prefect/infrastructure/provisioners/ecs.py +++ b/src/prefect/infrastructure/provisioners/ecs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import contextlib import contextvars @@ -9,9 +11,11 @@ from copy import deepcopy from functools import partial from textwrap import dedent -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from types import ModuleType +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional import anyio +import anyio.to_thread from anyio import run_process from rich.console import Console from rich.panel import Panel @@ -33,13 +37,15 @@ if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient -boto3 = lazy_import("boto3") +boto3: ModuleType = lazy_import("boto3") -current_console = contextvars.ContextVar("console", default=Console()) +current_console: contextvars.ContextVar[Console] = contextvars.ContextVar( + "console", default=Console() +) @contextlib.contextmanager -def console_context(value: Console): +def console_context(value: Console) -> Generator[None, None, None]: token = current_console.set(value) try: yield @@ -73,7 +79,7 @@ async def get_task_count(self) -> int: """ return 1 if await self.requires_provisioning() else 0 - def _get_policy_by_name(self, name): + def _get_policy_by_name(self, name: str) -> dict[str, Any] | None: paginator = self._iam_client.get_paginator("list_policies") page_iterator = paginator.paginate(Scope="Local") @@ -119,9 +125,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - policy_document: Dict[str, Any], + policy_document: dict[str, Any], advance: Callable[[], None], - ): + ) -> str: """ Provisions an IAM policy. @@ -153,7 +159,7 @@ async def provision( return policy["Arn"] @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -215,7 +221,7 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, advance: Callable[[], None], - ): + ) -> None: """ Provisions an IAM user. @@ -231,7 +237,7 @@ async def provision( advance() @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -241,7 +247,7 @@ def __init__(self, user_name: str, block_document_name: str): self._user_name = user_name self._requires_provisioning = None - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -357,7 +363,7 @@ async def provision( } @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -374,7 +380,7 @@ def __init__( credentials_block_name or f"{work_pool_name}-aws-credentials" ) self._policy_name = policy_name - self._policy_document = { + self._policy_document: dict[str, Any] = { "Version": "2012-10-17", "Statement": [ { @@ -417,7 +423,11 @@ def __init__( self._execution_role_resource = ExecutionRoleResource() @property - def resources(self): + def resources( + self, + ) -> list[ + "ExecutionRoleResource | IamUserResource | IamPolicyResource | CredentialsBlockResource" + ]: return [ self._execution_role_resource, self._iam_user_resource, @@ -425,7 +435,7 @@ def resources(self): self._credentials_block_resource, ] - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -461,9 +471,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions the authentication resources. @@ -507,7 +517,7 @@ async def provision( ) @property - def next_steps(self): + def next_steps(self) -> list[str]: return [ next_step for resource in self.resources @@ -521,7 +531,7 @@ def __init__(self, cluster_name: str = "prefect-ecs-cluster"): self._cluster_name = cluster_name self._requires_provisioning = None - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -566,9 +576,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions an ECS cluster. @@ -592,7 +602,7 @@ async def provision( ] = self._cluster_name @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -608,7 +618,7 @@ def __init__( self._requires_provisioning = None self._ecs_security_group_name = ecs_security_group_name - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -642,7 +652,9 @@ async def _get_existing_vpc_cidrs(self): response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs) return [vpc["CidrBlock"] for vpc in response["Vpcs"]] - async def _find_non_overlapping_cidr(self, default_cidr="172.31.0.0/16"): + async def _find_non_overlapping_cidr( + self, default_cidr: str = "172.31.0.0/16" + ) -> str: """Find a non-overlapping CIDR block""" response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs) existing_cidrs = [vpc["CidrBlock"] for vpc in response["Vpcs"]] @@ -708,9 +720,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions a VPC. @@ -768,7 +780,7 @@ async def provision( ) )["AvailabilityZones"] zones = [az["ZoneName"] for az in azs] - subnets = [] + subnets: list[Any] = [] for i, subnet_cidr in enumerate(subnet_cidrs[0:3]): subnets.append( await anyio.to_thread.run_sync( @@ -828,7 +840,7 @@ async def provision( ) @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -838,9 +850,9 @@ def __init__(self, work_pool_name: str, repository_name: str = "prefect-flows"): self._repository_name = repository_name self._requires_provisioning = None self._work_pool_name = work_pool_name - self._next_steps = [] + self._next_steps: list[str | Panel] = [] - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -895,9 +907,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions an ECR repository. @@ -978,7 +990,7 @@ def my_flow(name: str = "world"): ) @property - def next_steps(self): + def next_steps(self) -> list[str | Panel]: return self._next_steps @@ -1000,7 +1012,7 @@ def __init__(self, execution_role_name: str = "PrefectEcsTaskExecutionRole"): ) self._requires_provisioning = None - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -1046,9 +1058,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> str: """ Provisions an IAM role. @@ -1087,7 +1099,7 @@ async def provision( return response["Role"]["Arn"] @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -1100,11 +1112,11 @@ def __init__(self): self._console = Console() @property - def console(self): + def console(self) -> Console: return self._console @console.setter - def console(self, value): + def console(self, value: Console) -> None: self._console = value async def _prompt_boto3_installation(self): @@ -1115,7 +1127,7 @@ async def _prompt_boto3_installation(self): boto3 = importlib.import_module("boto3") @staticmethod - def is_boto3_installed(): + def is_boto3_installed() -> bool: """ Check if boto3 is installed. """ @@ -1157,8 +1169,8 @@ def _generate_resources( async def provision( self, work_pool_name: str, - base_job_template: dict, - ) -> Dict[str, Any]: + base_job_template: dict[str, Any], + ) -> dict[str, Any]: """ Provisions the infrastructure for an ECS push work pool. @@ -1310,7 +1322,7 @@ async def provision( # provision calls will be no-ops, but update the base job template base_job_template_copy = deepcopy(base_job_template) - next_steps = [] + next_steps: list[str | Panel] = [] with Progress(console=self._console, disable=num_tasks == 0) as progress: task = progress.add_task( "Provisioning Infrastructure", diff --git a/src/prefect/infrastructure/provisioners/modal.py b/src/prefect/infrastructure/provisioners/modal.py index 274775817f86..04960d9c5d89 100644 --- a/src/prefect/infrastructure/provisioners/modal.py +++ b/src/prefect/infrastructure/provisioners/modal.py @@ -2,6 +2,7 @@ import shlex import sys from copy import deepcopy +from types import ModuleType from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from anyio import run_process @@ -19,7 +20,7 @@ from prefect.client.orchestration import PrefectClient -modal = lazy_import("modal") +modal: ModuleType = lazy_import("modal") class ModalPushProvisioner: @@ -28,14 +29,14 @@ class ModalPushProvisioner: """ def __init__(self, client: Optional["PrefectClient"] = None): - self._console = Console() + self._console: Console = Console() @property - def console(self): + def console(self) -> Console: return self._console @console.setter - def console(self, value): + def console(self, value: Console) -> None: self._console = value @staticmethod diff --git a/src/prefect/input/actions.py b/src/prefect/input/actions.py index 9e88e5c59b69..e771ca491c39 100644 --- a/src/prefect/input/actions.py +++ b/src/prefect/input/actions.py @@ -1,3 +1,4 @@ +import inspect from typing import TYPE_CHECKING, Any, Optional, Set from uuid import UUID @@ -44,9 +45,12 @@ async def create_flow_run_input_from_model( else: json_safe = orjson.loads(model_instance.json()) - await create_flow_run_input( + coro = create_flow_run_input( key=key, value=json_safe, flow_run_id=flow_run_id, sender=sender ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + await coro @sync_compatible diff --git a/src/prefect/input/run_input.py b/src/prefect/input/run_input.py index 75434eb4cec4..1567b0d2ff72 100644 --- a/src/prefect/input/run_input.py +++ b/src/prefect/input/run_input.py @@ -60,11 +60,15 @@ async def receiver_flow(): ``` """ +from __future__ import annotations + +import inspect from inspect import isclass from typing import ( TYPE_CHECKING, Any, ClassVar, + Coroutine, Dict, Generic, Literal, @@ -81,6 +85,7 @@ async def receiver_flow(): import anyio import pydantic from pydantic import ConfigDict +from typing_extensions import Self from prefect.input.actions import ( create_flow_run_input, @@ -144,7 +149,7 @@ class RunInputMetadata(pydantic.BaseModel): receiver: UUID -class RunInput(pydantic.BaseModel): +class BaseRunInput(pydantic.BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") _description: Optional[str] = pydantic.PrivateAttr(default=None) @@ -172,23 +177,29 @@ async def save(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None): if is_v2_model(cls): schema = create_v2_schema(cls.__name__, model_base=cls) else: - schema = cls.schema(by_alias=True) + schema = cls.model_json_schema(by_alias=True) - await create_flow_run_input( + coro = create_flow_run_input( key=keyset["schema"], value=schema, flow_run_id=flow_run_id ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + await coro description = cls._description if isinstance(cls._description, str) else None if description: - await create_flow_run_input( + coro = create_flow_run_input( key=keyset["description"], value=description, flow_run_id=flow_run_id, ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + await coro @classmethod @sync_compatible - async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None): + async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> Self: """ Load the run input response from the given key. @@ -208,7 +219,7 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None): return instance @classmethod - def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput"): + def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput") -> Self: """ Load the run input from a FlowRunInput object. @@ -284,6 +295,8 @@ async def send_to( key_prefix=key_prefix, ) + +class RunInput(BaseRunInput): @classmethod def receive( cls, @@ -293,7 +306,7 @@ def receive( exclude_keys: Optional[Set[str]] = None, key_prefix: Optional[str] = None, flow_run_id: Optional[UUID] = None, - ): + ) -> GetInputHandler[Self]: if key_prefix is None: key_prefix = f"{cls.__name__.lower()}-auto" @@ -322,12 +335,12 @@ def subclass_from_base_model_type( return type(f"{model_cls.__name__}RunInput", (RunInput, model_cls), {}) # type: ignore -class AutomaticRunInput(RunInput, Generic[T]): +class AutomaticRunInput(BaseRunInput, Generic[T]): value: T @classmethod @sync_compatible - async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T: + async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> Self: """ Load the run input response from the given key. @@ -335,7 +348,10 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T: - keyset (Keyset): the keyset to load the input for - flow_run_id (UUID, optional): the flow run ID to load the input for """ - instance = await super().load(keyset, flow_run_id=flow_run_id) + instance_coro = super().load(keyset, flow_run_id=flow_run_id) + if TYPE_CHECKING: + assert inspect.iscoroutine(instance_coro) + instance = await instance_coro return instance.value @classmethod @@ -370,17 +386,34 @@ def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput[T]"]: # Creating a new Pydantic model class dynamically with the name based # on the type prefix. - new_cls: Type["AutomaticRunInput"] = pydantic.create_model( + new_cls: Type["AutomaticRunInput[T]"] = pydantic.create_model( class_name, **fields, __base__=AutomaticRunInput ) return new_cls @classmethod - def receive(cls, *args, **kwargs): - if kwargs.get("key_prefix") is None: - kwargs["key_prefix"] = f"{cls.__name__.lower()}-auto" + def receive( + cls, + timeout: Optional[float] = 3600, + poll_interval: float = 10, + raise_timeout_error: bool = False, + exclude_keys: Optional[Set[str]] = None, + key_prefix: Optional[str] = None, + flow_run_id: Optional[UUID] = None, + with_metadata: bool = False, + ) -> GetAutomaticInputHandler[T]: + key_prefix = key_prefix or f"{cls.__name__.lower()}-auto" - return GetAutomaticInputHandler(run_input_cls=cls, *args, **kwargs) + return GetAutomaticInputHandler( + run_input_cls=cls, + key_prefix=key_prefix, + timeout=timeout, + poll_interval=poll_interval, + raise_timeout_error=raise_timeout_error, + exclude_keys=exclude_keys, + flow_run_id=flow_run_id, + with_metadata=with_metadata, + ) def run_input_subclass_from_type( @@ -409,24 +442,24 @@ def __init__( self, run_input_cls: Type[R], key_prefix: str, - timeout: Optional[float] = 3600, + timeout: float | None = 3600, poll_interval: float = 10, raise_timeout_error: bool = False, exclude_keys: Optional[Set[str]] = None, flow_run_id: Optional[UUID] = None, ): - self.run_input_cls = run_input_cls - self.key_prefix = key_prefix - self.timeout = timeout - self.poll_interval = poll_interval - self.exclude_keys = set() - self.raise_timeout_error = raise_timeout_error - self.flow_run_id = ensure_flow_run_id(flow_run_id) + self.run_input_cls: Type[R] = run_input_cls + self.key_prefix: str = key_prefix + self.timeout: float | None = timeout + self.poll_interval: float = poll_interval + self.exclude_keys: set[str] = set() + self.raise_timeout_error: bool = raise_timeout_error + self.flow_run_id: UUID = ensure_flow_run_id(flow_run_id) if exclude_keys is not None: self.exclude_keys.update(exclude_keys) - def __iter__(self): + def __iter__(self) -> Self: return self def __next__(self) -> R: @@ -437,24 +470,31 @@ def __next__(self) -> R: raise raise StopIteration - def __aiter__(self): + def __aiter__(self) -> Self: return self async def __anext__(self) -> R: try: - return await self.next() + coro = self.next() + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + return await coro except TimeoutError: if self.raise_timeout_error: raise raise StopAsyncIteration - async def filter_for_inputs(self): - flow_run_inputs = await filter_flow_run_input( + async def filter_for_inputs(self) -> list["FlowRunInput"]: + flow_run_inputs_coro = filter_flow_run_input( key_prefix=self.key_prefix, limit=1, exclude_keys=self.exclude_keys, flow_run_id=self.flow_run_id, ) + if TYPE_CHECKING: + assert inspect.iscoroutine(flow_run_inputs_coro) + + flow_run_inputs = await flow_run_inputs_coro if flow_run_inputs: self.exclude_keys.add(*[i.key for i in flow_run_inputs]) @@ -478,22 +518,91 @@ async def next(self) -> R: return self.to_instance(flow_run_inputs[0]) -class GetAutomaticInputHandler(GetInputHandler, Generic[T]): - def __init__(self, *args, **kwargs): - self.with_metadata = kwargs.pop("with_metadata", False) - super().__init__(*args, **kwargs) +class GetAutomaticInputHandler(Generic[T]): + def __init__( + self, + run_input_cls: Type[AutomaticRunInput[T]], + key_prefix: str, + timeout: float | None = 3600, + poll_interval: float = 10, + raise_timeout_error: bool = False, + exclude_keys: Optional[Set[str]] = None, + flow_run_id: Optional[UUID] = None, + with_metadata: bool = False, + ): + self.run_input_cls: Type[AutomaticRunInput[T]] = run_input_cls + self.key_prefix: str = key_prefix + self.timeout: float | None = timeout + self.poll_interval: float = poll_interval + self.exclude_keys: set[str] = set() + self.raise_timeout_error: bool = raise_timeout_error + self.flow_run_id: UUID = ensure_flow_run_id(flow_run_id) + self.with_metadata = with_metadata - def __next__(self) -> T: - return cast(T, super().__next__()) + if exclude_keys is not None: + self.exclude_keys.update(exclude_keys) - async def __anext__(self) -> T: - return cast(T, await super().__anext__()) + def __iter__(self) -> Self: + return self + + def __next__(self) -> T | AutomaticRunInput[T]: + try: + not_coro = self.next() + if TYPE_CHECKING: + assert not isinstance(not_coro, Coroutine) + return not_coro + except TimeoutError: + if self.raise_timeout_error: + raise + raise StopIteration + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> Union[T, AutomaticRunInput[T]]: + try: + coro = self.next() + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + return cast(Union[T, AutomaticRunInput[T]], await coro) + except TimeoutError: + if self.raise_timeout_error: + raise + raise StopAsyncIteration + + async def filter_for_inputs(self) -> list["FlowRunInput"]: + flow_run_inputs_coro = filter_flow_run_input( + key_prefix=self.key_prefix, + limit=1, + exclude_keys=self.exclude_keys, + flow_run_id=self.flow_run_id, + ) + if TYPE_CHECKING: + assert inspect.iscoroutine(flow_run_inputs_coro) + + flow_run_inputs = await flow_run_inputs_coro + + if flow_run_inputs: + self.exclude_keys.add(*[i.key for i in flow_run_inputs]) + + return flow_run_inputs @sync_compatible - async def next(self) -> T: - return cast(T, await super().next()) + async def next(self) -> Union[T, AutomaticRunInput[T]]: + flow_run_inputs = await self.filter_for_inputs() + if flow_run_inputs: + return self.to_instance(flow_run_inputs[0]) - def to_instance(self, flow_run_input: "FlowRunInput") -> T: + with anyio.fail_after(self.timeout): + while True: + await anyio.sleep(self.poll_interval) + flow_run_inputs = await self.filter_for_inputs() + if flow_run_inputs: + return self.to_instance(flow_run_inputs[0]) + + def to_instance( + self, flow_run_input: "FlowRunInput" + ) -> Union[T, AutomaticRunInput[T]]: run_input = self.run_input_cls.load_from_flow_run_input(flow_run_input) if self.with_metadata: @@ -503,14 +612,15 @@ def to_instance(self, flow_run_input: "FlowRunInput") -> T: async def _send_input( flow_run_id: UUID, - run_input: Any, + run_input: RunInput | pydantic.BaseModel, sender: Optional[str] = None, key_prefix: Optional[str] = None, ): + _run_input: Union[RunInput, AutomaticRunInput[Any]] if isinstance(run_input, RunInput): - _run_input: RunInput = run_input + _run_input = run_input else: - input_cls: Type[AutomaticRunInput] = run_input_subclass_from_type( + input_cls: Type[AutomaticRunInput[Any]] = run_input_subclass_from_type( type(run_input) ) _run_input = input_cls(value=run_input) @@ -520,9 +630,13 @@ async def _send_input( key = f"{key_prefix}-{uuid4()}" - await create_flow_run_input_from_model( + coro = create_flow_run_input_from_model( key=key, flow_run_id=flow_run_id, model_instance=_run_input, sender=sender ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + + await coro @sync_compatible diff --git a/src/prefect/logging/configuration.py b/src/prefect/logging/configuration.py index 9b666668e33d..73216645cd99 100644 --- a/src/prefect/logging/configuration.py +++ b/src/prefect/logging/configuration.py @@ -6,7 +6,7 @@ import warnings from functools import partial from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional import yaml @@ -24,10 +24,10 @@ PROCESS_LOGGING_CONFIG: Optional[Dict[str, Any]] = None # Regex call to replace non-alphanumeric characters to '_' to create a valid env var -to_envvar = partial(re.sub, re.compile(r"[^0-9a-zA-Z]+"), "_") +to_envvar: Callable[[str], str] = partial(re.sub, re.compile(r"[^0-9a-zA-Z]+"), "_") -def load_logging_config(path: Path) -> dict: +def load_logging_config(path: Path) -> dict[str, Any]: """ Loads logging configuration from a path allowing override from the environment """ diff --git a/src/prefect/logging/filters.py b/src/prefect/logging/filters.py index 43deb4847a6a..013025063d2a 100644 --- a/src/prefect/logging/filters.py +++ b/src/prefect/logging/filters.py @@ -5,7 +5,7 @@ from prefect.utilities.names import obfuscate -def redact_substr(obj: Any, substr: str): +def redact_substr(obj: Any, substr: str) -> Any: """ Redact a string from a potentially nested object. @@ -17,7 +17,7 @@ def redact_substr(obj: Any, substr: str): Any: The object with the API key redacted. """ - def redact_item(item): + def redact_item(item: Any) -> Any: if isinstance(item, str): return item.replace(substr, obfuscate(substr)) return item diff --git a/src/prefect/logging/formatters.py b/src/prefect/logging/formatters.py index a8eb92c8170d..9fe2c752984b 100644 --- a/src/prefect/logging/formatters.py +++ b/src/prefect/logging/formatters.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import logging.handlers import sys import traceback from types import TracebackType -from typing import Optional, Tuple, Type, Union +from typing import Any, Literal, Optional, Tuple, Type, Union import orjson @@ -14,7 +16,7 @@ ] -def format_exception_info(exc_info: ExceptionInfoType) -> dict: +def format_exception_info(exc_info: ExceptionInfoType) -> dict[str, Any]: # if sys.exc_info() returned a (None, None, None) tuple, # then there's nothing to format if exc_info[0] is None: @@ -40,13 +42,15 @@ class JsonFormatter(logging.Formatter): newlines. """ - def __init__(self, fmt, dmft, style) -> None: # noqa + def __init__( + self, fmt: Literal["pretty", "default"], dmft: str, style: str + ) -> None: # noqa super().__init__() if fmt not in ["pretty", "default"]: raise ValueError("Format must be either 'pretty' or 'default'.") - self.serializer = JSONSerializer( + self.serializer: JSONSerializer = JSONSerializer( jsonlib="orjson", dumps_kwargs={"option": orjson.OPT_INDENT_2} if fmt == "pretty" else {}, ) @@ -72,13 +76,13 @@ def format(self, record: logging.LogRecord) -> str: class PrefectFormatter(logging.Formatter): def __init__( self, - format=None, - datefmt=None, - style="%", - validate=True, + format: str | None = None, + datefmt: str | None = None, + style: str = "%", + validate: bool = True, *, - defaults=None, - task_run_fmt: Optional[str] = None, + defaults: dict[str, Any] | None = None, + task_run_fmt: str | None = None, flow_run_fmt: Optional[str] = None, ) -> None: """ @@ -118,7 +122,7 @@ def __init__( self._flow_run_style.validate() self._task_run_style.validate() - def formatMessage(self, record: logging.LogRecord): + def formatMessage(self, record: logging.LogRecord) -> str: if record.name == "prefect.flow_runs": style = self._flow_run_style elif record.name == "prefect.task_runs": diff --git a/src/prefect/logging/handlers.py b/src/prefect/logging/handlers.py index a8323cff6f5b..4b76bef55688 100644 --- a/src/prefect/logging/handlers.py +++ b/src/prefect/logging/handlers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import sys @@ -6,7 +8,7 @@ import uuid import warnings from contextlib import asynccontextmanager -from typing import Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, TextIO, Type import pendulum from rich.console import Console @@ -34,6 +36,14 @@ PREFECT_LOGGING_TO_API_WHEN_MISSING_FLOW, ) +if sys.version_info >= (3, 12): + StreamHandler = logging.StreamHandler[TextIO] +else: + if TYPE_CHECKING: + StreamHandler = logging.StreamHandler[TextIO] + else: + StreamHandler = logging.StreamHandler + class APILogWorker(BatchedQueueService[Dict[str, Any]]): @property @@ -90,7 +100,7 @@ class APILogHandler(logging.Handler): """ @classmethod - def flush(cls): + def flush(cls) -> None: """ Tell the `APILogWorker` to send any currently enqueued logs and block until completion. @@ -118,7 +128,7 @@ def flush(cls): return APILogWorker.drain_all(timeout=5) @classmethod - async def aflush(cls): + async def aflush(cls) -> bool: """ Tell the `APILogWorker` to send any currently enqueued logs and block until completion. @@ -126,7 +136,7 @@ async def aflush(cls): return await APILogWorker.drain_all() - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: """ Send a log to the `APILogWorker` """ @@ -239,7 +249,7 @@ def _get_payload_size(self, log: Dict[str, Any]) -> int: class WorkerAPILogHandler(APILogHandler): - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: # Open-source API servers do not currently support worker logs, and # worker logs only have an associated worker ID when connected to Cloud, # so we won't send worker logs to the API unless they have a worker ID. @@ -278,13 +288,13 @@ def prepare(self, record: logging.LogRecord) -> Dict[str, Any]: return log -class PrefectConsoleHandler(logging.StreamHandler): +class PrefectConsoleHandler(StreamHandler): def __init__( self, - stream=None, - highlighter: Highlighter = PrefectConsoleHighlighter, - styles: Optional[Dict[str, str]] = None, - level: Union[int, str] = logging.NOTSET, + stream: TextIO | None = None, + highlighter: type[Highlighter] = PrefectConsoleHighlighter, + styles: dict[str, str] | None = None, + level: int | str = logging.NOTSET, ): """ The default console handler for Prefect, which highlights log levels, @@ -307,14 +317,14 @@ def __init__( theme = Theme(inherit=False) self.level = level - self.console = Console( + self.console: Console = Console( highlighter=highlighter, theme=theme, file=self.stream, markup=markup_console, ) - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: try: message = self.format(record) self.console.print(message, soft_wrap=True) diff --git a/src/prefect/logging/highlighters.py b/src/prefect/logging/highlighters.py index b842f7c95240..ac9b84a273d0 100644 --- a/src/prefect/logging/highlighters.py +++ b/src/prefect/logging/highlighters.py @@ -7,7 +7,7 @@ class LevelHighlighter(RegexHighlighter): """Apply style to log levels.""" base_style = "level." - highlights = [ + highlights: list[str] = [ r"(?PDEBUG)", r"(?PINFO)", r"(?PWARNING)", @@ -20,7 +20,7 @@ class UrlHighlighter(RegexHighlighter): """Apply style to urls.""" base_style = "url." - highlights = [ + highlights: list[str] = [ r"(?P(https|http|ws|wss):\/\/[0-9a-zA-Z\$\-\_\+\!`\(\)\,\.\?\/\;\:\&\=\%\#]*)", r"(?P(file):\/\/[0-9a-zA-Z\$\-\_\+\!`\(\)\,\.\?\/\;\:\&\=\%\#]*)", ] @@ -30,7 +30,7 @@ class NameHighlighter(RegexHighlighter): """Apply style to names.""" base_style = "name." - highlights = [ + highlights: list[str] = [ # ?i means case insensitive # ?<= means find string right after the words: flow run r"(?i)(?P(?<=flow run) \'(.*?)\')", @@ -44,7 +44,7 @@ class StateHighlighter(RegexHighlighter): """Apply style to states.""" base_style = "state." - highlights = [ + highlights: list[str] = [ rf"(?P<{state.lower()}_state>{state.title()})" for state in StateType ] + [ r"(?PCached)(?=\(type=COMPLETED\))" # Highlight only "Cached" @@ -55,7 +55,7 @@ class PrefectConsoleHighlighter(RegexHighlighter): """Applies style from multiple highlighters.""" base_style = "log." - highlights = ( + highlights: list[str] = ( LevelHighlighter.highlights + UrlHighlighter.highlights + NameHighlighter.highlights diff --git a/src/prefect/logging/loggers.py b/src/prefect/logging/loggers.py index 45a2f6195f73..3021ccec4e32 100644 --- a/src/prefect/logging/loggers.py +++ b/src/prefect/logging/loggers.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from functools import lru_cache from logging import LogRecord -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, List, Mapping, MutableMapping, Optional, Union from typing_extensions import Self @@ -15,6 +15,14 @@ from prefect.logging.filters import ObfuscateApiKeyFilter from prefect.telemetry.logging import add_telemetry_log_handler +if sys.version_info >= (3, 12): + LoggingAdapter = logging.LoggerAdapter[logging.Logger] +else: + if TYPE_CHECKING: + LoggingAdapter = logging.LoggerAdapter[logging.Logger] + else: + LoggingAdapter = logging.LoggerAdapter + if TYPE_CHECKING: from prefect.client.schemas import FlowRun as ClientFlowRun from prefect.client.schemas.objects import FlowRun, TaskRun @@ -23,11 +31,6 @@ from prefect.tasks import Task from prefect.workers.base import BaseWorker -if sys.version_info >= (3, 12): - LoggingAdapter = logging.LoggerAdapter[logging.Logger] -else: - LoggingAdapter = logging.LoggerAdapter - class PrefectLogAdapter(LoggingAdapter): """ @@ -39,9 +42,9 @@ class PrefectLogAdapter(LoggingAdapter): not a bug in the LoggingAdapter and subclassing is the intended workaround. """ - extra: Mapping[str, object] | None - - def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]: # type: ignore[incompatibleMethodOverride] + def process( + self, msg: str, kwargs: MutableMapping[str, Any] + ) -> tuple[str, MutableMapping[str, Any]]: kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})} return (msg, kwargs) @@ -164,7 +167,7 @@ def flow_run_logger( flow_run: Union["FlowRun", "ClientFlowRun"], flow: Optional["Flow[Any, Any]"] = None, **kwargs: str, -) -> LoggingAdapter: +) -> PrefectLogAdapter: """ Create a flow run logger with the run's metadata attached. @@ -192,7 +195,7 @@ def task_run_logger( flow_run: Optional["FlowRun"] = None, flow: Optional["Flow[Any, Any]"] = None, **kwargs: Any, -): +) -> LoggingAdapter: """ Create a task run logger with the run's metadata attached. @@ -228,7 +231,9 @@ def task_run_logger( ) -def get_worker_logger(worker: "BaseWorker", name: Optional[str] = None): +def get_worker_logger( + worker: "BaseWorker[Any, Any, Any]", name: Optional[str] = None +) -> logging.Logger | LoggingAdapter: """ Create a worker logger with the worker's metadata attached. @@ -364,7 +369,9 @@ def __init__(self, eavesdrop_on: str, level: int = logging.NOTSET): # It's important that we use a very minimalistic formatter for use cases where # we may present these logs back to the user. We shouldn't leak filenames, # versions, or other environmental information. - self.formatter = logging.Formatter("[%(levelname)s]: %(message)s") + self.formatter: logging.Formatter | None = logging.Formatter( + "[%(levelname)s]: %(message)s" + ) def __enter__(self) -> Self: self._target_logger = logging.getLogger(self.eavesdrop_on) @@ -374,7 +381,7 @@ def __enter__(self) -> Self: self._lines = [] return self - def __exit__(self, *_): + def __exit__(self, *_: Any) -> None: if self._target_logger: self._target_logger.removeHandler(self) self._target_logger.level = self._original_level diff --git a/src/prefect/main.py b/src/prefect/main.py index d61e2c80e4d7..1e9e5421998f 100644 --- a/src/prefect/main.py +++ b/src/prefect/main.py @@ -3,7 +3,7 @@ from prefect.deployments import deploy from prefect.states import State from prefect.logging import get_run_logger -from prefect.flows import flow, Flow, serve, aserve +from prefect.flows import FlowDecorator, flow, Flow, serve, aserve from prefect.transactions import Transaction from prefect.tasks import task, Task from prefect.context import tags @@ -58,6 +58,8 @@ inject_renamed_module_alias_finder() +flow: FlowDecorator + # Declare API for type-checkers __all__ = [ diff --git a/src/prefect/runner/runner.py b/src/prefect/runner/runner.py index 27fbde6e1a36..d7f7d55eda7c 100644 --- a/src/prefect/runner/runner.py +++ b/src/prefect/runner/runner.py @@ -32,6 +32,8 @@ def fast_flow(): """ +from __future__ import annotations + import asyncio import datetime import logging @@ -63,13 +65,14 @@ def fast_flow(): import anyio.abc import pendulum from cachetools import LRUCache +from typing_extensions import Self from prefect._internal.concurrency.api import ( create_call, from_async, from_sync, ) -from prefect.client.orchestration import get_client +from prefect.client.orchestration import PrefectClient, get_client from prefect.client.schemas.filters import ( FlowRunFilter, FlowRunFilterId, @@ -196,25 +199,25 @@ def goodbye_flow(name): if name and ("/" in name or "%" in name): raise ValueError("Runner name cannot contain '/' or '%'") - self.name = Path(name).stem if name is not None else f"runner-{uuid4()}" - self._logger = get_logger("runner") - - self.started = False - self.stopping = False - self.pause_on_shutdown = pause_on_shutdown - self.limit = limit or settings.runner.process_limit - self.webserver = webserver - - self.query_seconds = query_seconds or settings.runner.poll_frequency - self._prefetch_seconds = prefetch_seconds - self.heartbeat_seconds = ( + self.name: str = Path(name).stem if name is not None else f"runner-{uuid4()}" + self._logger: "logging.Logger" = get_logger("runner") + + self.started: bool = False + self.stopping: bool = False + self.pause_on_shutdown: bool = pause_on_shutdown + self.limit: int | None = limit or settings.runner.process_limit + self.webserver: bool = webserver + + self.query_seconds: float = query_seconds or settings.runner.poll_frequency + self._prefetch_seconds: float = prefetch_seconds + self.heartbeat_seconds: float | None = ( heartbeat_seconds or settings.runner.heartbeat_frequency ) if self.heartbeat_seconds is not None and self.heartbeat_seconds < 30: raise ValueError("Heartbeat must be 30 seconds or greater.") - self._limiter: Optional[anyio.CapacityLimiter] = None - self._client = get_client() + self._limiter: anyio.CapacityLimiter | None = None + self._client: PrefectClient = get_client() self._submitting_flow_run_ids: set[UUID] = set() self._cancelling_flow_run_ids: set[UUID] = set() self._scheduled_task_scopes: set[UUID] = set() @@ -224,8 +227,8 @@ def goodbye_flow(name): self._tmp_dir: Path = ( Path(tempfile.gettempdir()) / "runner_storage" / str(uuid4()) ) - self._storage_objs: List[RunnerStorage] = [] - self._deployment_storage_map: Dict[UUID, RunnerStorage] = {} + self._storage_objs: list[RunnerStorage] = [] + self._deployment_storage_map: dict[UUID, RunnerStorage] = {} self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -488,7 +491,7 @@ def execute_in_background( return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop) - async def cancel_all(self): + async def cancel_all(self) -> None: runs_to_cancel = [] # done to avoid dictionary size changing during iteration @@ -525,7 +528,7 @@ async def stop(self): async def execute_flow_run( self, flow_run_id: UUID, entrypoint: Optional[str] = None - ): + ) -> None: """ Executes a single flow run with the given ID. @@ -777,7 +780,7 @@ async def _get_and_submit_flow_runs(self): if self.stopping: return runs_response = await self._get_scheduled_flow_runs() - self.last_polled = pendulum.now("UTC") + self.last_polled: pendulum.DateTime = pendulum.now("UTC") return await self._submit_scheduled_flow_runs(flow_run_response=runs_response) async def _check_for_cancelled_flow_runs( @@ -1390,7 +1393,7 @@ async def _run_on_crashed_hooks( await _run_hooks(hooks, flow_run, flow, state) - async def __aenter__(self): + async def __aenter__(self) -> Self: self._logger.debug("Starting runner...") self._client = get_client() self._tmp_dir.mkdir(parents=True) @@ -1412,7 +1415,7 @@ async def __aenter__(self): self.started = True return self - async def __aexit__(self, *exc_info: Any): + async def __aexit__(self, *exc_info: Any) -> None: self._logger.debug("Stopping runner...") if self.pause_on_shutdown: await self._pause_schedules() @@ -1430,7 +1433,7 @@ async def __aexit__(self, *exc_info: Any): shutil.rmtree(str(self._tmp_dir)) del self._runs_task_group, self._loops_task_group - def __repr__(self): + def __repr__(self) -> str: return f"Runner(name={self.name!r})" diff --git a/src/prefect/runner/server.py b/src/prefect/runner/server.py index 9a3688b09b5c..b58cbf02e4d5 100644 --- a/src/prefect/runner/server.py +++ b/src/prefect/runner/server.py @@ -26,12 +26,14 @@ from prefect.utilities.importtools import load_script_as_module if TYPE_CHECKING: + import logging + from prefect.client.schemas.responses import DeploymentResponse from prefect.runner import Runner from pydantic import BaseModel -logger = get_logger("webserver") +logger: "logging.Logger" = get_logger("webserver") RunnableEndpoint = Literal["deployment", "flow", "task"] diff --git a/src/prefect/runner/storage.py b/src/prefect/runner/storage.py index 7a374b65eb8d..031ab1a05139 100644 --- a/src/prefect/runner/storage.py +++ b/src/prefect/runner/storage.py @@ -33,7 +33,7 @@ class RunnerStorage(Protocol): remotely stored flow code. """ - def set_base_path(self, path: Path): + def set_base_path(self, path: Path) -> None: """ Sets the base path to use when pulling contents from remote storage to local storage. @@ -55,7 +55,7 @@ def destination(self) -> Path: """ ... - async def pull_code(self): + async def pull_code(self) -> None: """ Pulls contents from remote storage to the local filesystem. """ @@ -150,7 +150,7 @@ def __init__( def destination(self) -> Path: return self._storage_base_path / self._name - def set_base_path(self, path: Path): + def set_base_path(self, path: Path) -> None: self._storage_base_path = path @property @@ -221,7 +221,7 @@ async def is_sparsely_checked_out(self) -> bool: except Exception: return False - async def pull_code(self): + async def pull_code(self) -> None: """ Pulls the contents of the configured repository to the local filesystem. """ @@ -324,7 +324,7 @@ async def _clone_repo(self): cwd=self.destination, ) - def __eq__(self, __value) -> bool: + def __eq__(self, __value: Any) -> bool: if isinstance(__value, GitRepository): return ( self._url == __value._url @@ -339,7 +339,7 @@ def __repr__(self) -> str: f" branch={self._branch!r})" ) - def to_pull_step(self) -> Dict: + def to_pull_step(self) -> dict[str, Any]: pull_step = { "prefect.deployments.steps.git_clone": { "repository": self._url, @@ -466,7 +466,7 @@ def replace_blocks_with_values(obj: Any) -> Any: return fsspec.filesystem(scheme, **settings_with_block_values) - def set_base_path(self, path: Path): + def set_base_path(self, path: Path) -> None: self._storage_base_path = path @property @@ -492,7 +492,7 @@ def _remote_path(self) -> Path: _, netloc, urlpath, _, _ = urlsplit(self._url) return Path(netloc) / Path(urlpath.lstrip("/")) - async def pull_code(self): + async def pull_code(self) -> None: """ Pulls contents from remote storage to the local filesystem. """ @@ -522,7 +522,7 @@ async def pull_code(self): f" {self.destination!r}" ) from exc - def to_pull_step(self) -> dict: + def to_pull_step(self) -> dict[str, Any]: """ Returns a dictionary representation of the storage object that can be used as a deployment pull step. @@ -551,7 +551,7 @@ def replace_block_with_placeholder(obj: Any) -> Any: ] = required_package return step - def __eq__(self, __value) -> bool: + def __eq__(self, __value: Any) -> bool: """ Equality check for runner storage objects. """ @@ -590,7 +590,7 @@ def __init__( else str(uuid4()) ) - def set_base_path(self, path: Path): + def set_base_path(self, path: Path) -> None: self._storage_base_path = path @property @@ -601,12 +601,12 @@ def pull_interval(self) -> Optional[int]: def destination(self) -> Path: return self._storage_base_path / self._name - async def pull_code(self): + async def pull_code(self) -> None: if not self.destination.exists(): self.destination.mkdir(parents=True, exist_ok=True) await self._block.get_directory(local_path=str(self.destination)) - def to_pull_step(self) -> dict: + def to_pull_step(self) -> dict[str, Any]: # Give blocks the change to implement their own pull step if hasattr(self._block, "get_pull_step"): return self._block.get_pull_step() @@ -623,7 +623,7 @@ def to_pull_step(self) -> dict: } } - def __eq__(self, __value) -> bool: + def __eq__(self, __value: Any) -> bool: if isinstance(__value, BlockStorageAdapter): return self._block == __value._block return False @@ -658,19 +658,19 @@ def __init__( def destination(self) -> Path: return self._path - def set_base_path(self, path: Path): + def set_base_path(self, path: Path) -> None: self._storage_base_path = path @property def pull_interval(self) -> Optional[int]: return self._pull_interval - async def pull_code(self): + async def pull_code(self) -> None: # Local storage assumes the code already exists on the local filesystem # and does not need to be pulled from a remote location pass - def to_pull_step(self) -> dict: + def to_pull_step(self) -> dict[str, Any]: """ Returns a dictionary representation of the storage object that can be used as a deployment pull step. @@ -682,7 +682,7 @@ def to_pull_step(self) -> dict: } return step - def __eq__(self, __value) -> bool: + def __eq__(self, __value: Any) -> bool: if isinstance(__value, LocalStorage): return self._path == __value._path return False diff --git a/src/prefect/runner/submit.py b/src/prefect/runner/submit.py index ec42a4029a79..56c14b392e98 100644 --- a/src/prefect/runner/submit.py +++ b/src/prefect/runner/submit.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import asyncio import inspect import uuid -from typing import Any, Dict, List, Optional, Union, overload +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload import anyio import httpx -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from prefect.client.orchestration import get_client from prefect.client.schemas.filters import FlowRunFilter, TaskRunFilter @@ -22,12 +24,17 @@ from prefect.tasks import Task from prefect.utilities.asyncutils import sync_compatible -logger = get_logger("webserver") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("webserver") + +FlowOrTask: TypeAlias = Union[Flow[Any, Any], Task[Any, Any]] async def _submit_flow_to_runner( - flow: Flow, - parameters: Dict[str, Any], + flow: Flow[Any, Any], + parameters: dict[str, Any], retry_failed_submissions: bool = True, ) -> FlowRun: """ @@ -91,7 +98,7 @@ async def _submit_flow_to_runner( @overload def submit_to_runner( - prefect_callable: Union[Flow, Task], + prefect_callable: Union[Flow[Any, Any], Task[Any, Any]], parameters: Dict[str, Any], retry_failed_submissions: bool = True, ) -> FlowRun: @@ -100,19 +107,19 @@ def submit_to_runner( @overload def submit_to_runner( - prefect_callable: Union[Flow, Task], - parameters: List[Dict[str, Any]], + prefect_callable: Union[Flow[Any, Any], Task[Any, Any]], + parameters: list[dict[str, Any]], retry_failed_submissions: bool = True, -) -> List[FlowRun]: +) -> list[FlowRun]: ... @sync_compatible async def submit_to_runner( - prefect_callable: Union[Flow, Task], - parameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + prefect_callable: FlowOrTask, + parameters: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None, retry_failed_submissions: bool = True, -) -> Union[FlowRun, List[FlowRun]]: +) -> Union[FlowRun, list[FlowRun]]: """ Submit a callable in the background via the runner webserver one or more times. diff --git a/src/prefect/runtime/deployment.py b/src/prefect/runtime/deployment.py index c0acff8c39bb..9dd2c49938d2 100644 --- a/src/prefect/runtime/deployment.py +++ b/src/prefect/runtime/deployment.py @@ -25,8 +25,10 @@ def get_task_runner(): object or those directly provided via API for this run """ +from __future__ import annotations + import os -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, List, Optional from prefect._internal.concurrency.api import create_call, from_sync from prefect.client.orchestration import get_client @@ -34,12 +36,17 @@ def get_task_runner(): from .flow_run import _get_flow_run +if TYPE_CHECKING: + from prefect.client.schemas.responses import DeploymentResponse + __all__ = ["id", "flow_run_id", "name", "parameters", "version"] -CACHED_DEPLOYMENT = {} +CACHED_DEPLOYMENT: dict[str, "DeploymentResponse"] = {} -type_cast = { +type_cast: dict[ + type[bool] | type[int] | type[float] | type[str] | type[None], Callable[[Any], Any] +] = { bool: lambda x: x.lower() == "true", int: int, float: float, @@ -88,7 +95,7 @@ def __dir__() -> List[str]: return sorted(__all__) -async def _get_deployment(deployment_id): +async def _get_deployment(deployment_id: str) -> "DeploymentResponse": # deployments won't change between calls so let's avoid the lifecycle of a client if CACHED_DEPLOYMENT.get(deployment_id): return CACHED_DEPLOYMENT[deployment_id] @@ -115,7 +122,7 @@ def get_id() -> Optional[str]: return str(deployment_id) -def get_parameters() -> Dict: +def get_parameters() -> dict[str, Any]: run_id = get_flow_run_id() if run_id is None: return {} @@ -126,7 +133,7 @@ def get_parameters() -> Dict: return flow_run.parameters or {} -def get_name() -> Optional[Dict]: +def get_name() -> Optional[str]: dep_id = get_id() if dep_id is None: @@ -138,7 +145,7 @@ def get_name() -> Optional[Dict]: return deployment.name -def get_version() -> Optional[Dict]: +def get_version() -> Optional[str]: dep_id = get_id() if dep_id is None: @@ -154,7 +161,7 @@ def get_flow_run_id() -> Optional[str]: return os.getenv("PREFECT__FLOW_RUN_ID") -FIELDS = { +FIELDS: dict[str, Callable[[], Any]] = { "id": get_id, "flow_run_id": get_flow_run_id, "parameters": get_parameters, diff --git a/src/prefect/runtime/flow_run.py b/src/prefect/runtime/flow_run.py index 3c9be513bba8..0c4709e7aa9c 100644 --- a/src/prefect/runtime/flow_run.py +++ b/src/prefect/runtime/flow_run.py @@ -20,8 +20,10 @@ - `run_count`: the number of times this flow run has been run """ +from __future__ import annotations + import os -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import pendulum @@ -30,6 +32,9 @@ from prefect.context import FlowRunContext, TaskRunContext from prefect.settings import PREFECT_API_URL, PREFECT_UI_URL +if TYPE_CHECKING: + from prefect.client.schemas.objects import Flow, FlowRun, TaskRun + __all__ = [ "id", "tags", @@ -56,7 +61,15 @@ def _pendulum_parse(dt: str) -> pendulum.DateTime: return pendulum.parse(dt, tz=None, strict=False).set(tz="UTC") -type_cast = { +type_cast: dict[ + type[bool] + | type[int] + | type[float] + | type[str] + | type[None] + | type[pendulum.DateTime], + Callable[[Any], Any], +] = { bool: lambda x: x.lower() == "true", int: int, float: float, @@ -106,17 +119,17 @@ def __dir__() -> List[str]: return sorted(__all__) -async def _get_flow_run(flow_run_id): +async def _get_flow_run(flow_run_id: str) -> "FlowRun": async with get_client() as client: return await client.read_flow_run(flow_run_id) -async def _get_task_run(task_run_id): +async def _get_task_run(task_run_id: str) -> "TaskRun": async with get_client() as client: return await client.read_task_run(task_run_id) -async def _get_flow_from_run(flow_run_id): +async def _get_flow_from_run(flow_run_id: str) -> "Flow": async with get_client() as client: flow_run = await client.read_flow_run(flow_run_id) return await client.read_flow(flow_run.flow_id) @@ -323,7 +336,7 @@ def get_job_variables() -> Optional[Dict[str, Any]]: return flow_run_ctx.flow_run.job_variables if flow_run_ctx else None -FIELDS = { +FIELDS: dict[str, Callable[[], Any]] = { "id": get_id, "tags": get_tags, "scheduled_start_time": get_scheduled_start_time, diff --git a/src/prefect/runtime/task_run.py b/src/prefect/runtime/task_run.py index da28070fa826..9b3e4cd5b1f9 100644 --- a/src/prefect/runtime/task_run.py +++ b/src/prefect/runtime/task_run.py @@ -15,15 +15,19 @@ - `task_name`: the name of the task """ +from __future__ import annotations + import os -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from prefect.context import TaskRunContext __all__ = ["id", "tags", "name", "parameters", "run_count", "task_name"] -type_cast = { +type_cast: dict[ + type[bool] | type[int] | type[float] | type[str] | type[None], Callable[[Any], Any] +] = { bool: lambda x: x.lower() == "true", int: int, float: float, @@ -118,7 +122,7 @@ def get_parameters() -> Dict[str, Any]: return {} -FIELDS = { +FIELDS: dict[str, Callable[[], Any]] = { "id": get_id, "tags": get_tags, "name": get_name, diff --git a/src/prefect/server/api/admin.py b/src/prefect/server/api/admin.py index b55eb0bafd49..51c37460be64 100644 --- a/src/prefect/server/api/admin.py +++ b/src/prefect/server/api/admin.py @@ -9,7 +9,7 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/admin", tags=["Admin"]) +router: PrefectRouter = PrefectRouter(prefix="/admin", tags=["Admin"]) @router.get("/settings") @@ -37,7 +37,7 @@ async def clear_database( description="Pass confirm=True to confirm you want to modify the database.", ), response: Response = None, # type: ignore -): +) -> None: """Clear all database tables without dropping them.""" if not confirm: response.status_code = status.HTTP_400_BAD_REQUEST @@ -58,7 +58,7 @@ async def drop_database( description="Pass confirm=True to confirm you want to modify the database.", ), response: Response = None, -): +) -> None: """Drop all database objects.""" if not confirm: response.status_code = status.HTTP_400_BAD_REQUEST @@ -76,7 +76,7 @@ async def create_database( description="Pass confirm=True to confirm you want to modify the database.", ), response: Response = None, -): +) -> None: """Create all database objects.""" if not confirm: response.status_code = status.HTTP_400_BAD_REQUEST diff --git a/src/prefect/server/api/artifacts.py b/src/prefect/server/api/artifacts.py index d3300b1a348f..9e6c049c4a48 100644 --- a/src/prefect/server/api/artifacts.py +++ b/src/prefect/server/api/artifacts.py @@ -14,7 +14,7 @@ from prefect.server.schemas import actions, core, filters, sorting from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter( +router: PrefectRouter = PrefectRouter( prefix="/artifacts", tags=["Artifacts"], ) @@ -191,7 +191,7 @@ async def update_artifact( ..., description="The ID of the artifact to update.", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Update an artifact in the database. """ @@ -211,7 +211,7 @@ async def delete_artifact( ..., description="The ID of the artifact to delete.", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete an artifact from the database. """ diff --git a/src/prefect/server/api/automations.py b/src/prefect/server/api/automations.py index 832e66515f17..847442273177 100644 --- a/src/prefect/server/api/automations.py +++ b/src/prefect/server/api/automations.py @@ -27,7 +27,7 @@ ValidationError as JSONSchemaValidationError, ) -router = PrefectRouter( +router: PrefectRouter = PrefectRouter( prefix="/automations", tags=["Automations"], dependencies=[], @@ -91,7 +91,7 @@ async def update_automation( automation: AutomationUpdate, automation_id: UUID = Path(..., alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: # reset any client-provided IDs on the provided triggers automation.trigger.reset_ids() @@ -134,7 +134,7 @@ async def patch_automation( automation: AutomationPartialUpdate, automation_id: UUID = Path(..., alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: try: async with db.session_context(begin_transaction=True) as session: updated = await automations_models.update_automation( @@ -159,7 +159,7 @@ async def patch_automation( async def delete_automation( automation_id: UUID = Path(..., alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: deleted = await automations_models.delete_automation( session=session, @@ -228,7 +228,7 @@ async def read_automations_related_to_resource( async def delete_automations_owned_by_resource( resource_id: str = Path(..., alias="resource_id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: await automations_models.delete_automations_owned_by_resource( session, diff --git a/src/prefect/server/api/block_capabilities.py b/src/prefect/server/api/block_capabilities.py index 3da721adaf30..f386391f12ec 100644 --- a/src/prefect/server/api/block_capabilities.py +++ b/src/prefect/server/api/block_capabilities.py @@ -10,7 +10,9 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/block_capabilities", tags=["Block capabilities"]) +router: PrefectRouter = PrefectRouter( + prefix="/block_capabilities", tags=["Block capabilities"] +) @router.get("/") diff --git a/src/prefect/server/api/block_documents.py b/src/prefect/server/api/block_documents.py index dd8deb82b6dc..02153f6cf0e5 100644 --- a/src/prefect/server/api/block_documents.py +++ b/src/prefect/server/api/block_documents.py @@ -12,7 +12,9 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/block_documents", tags=["Block documents"]) +router: PrefectRouter = PrefectRouter( + prefix="/block_documents", tags=["Block documents"] +) @router.post("/", status_code=status.HTTP_201_CREATED) @@ -124,7 +126,7 @@ async def delete_block_document( ..., description="The block document id", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: result = await models.block_documents.delete_block_document( session=session, block_document_id=block_document_id @@ -142,7 +144,7 @@ async def update_block_document_data( ..., description="The block document id", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: try: async with db.session_context(begin_transaction=True) as session: result = await models.block_documents.update_block_document( diff --git a/src/prefect/server/api/block_schemas.py b/src/prefect/server/api/block_schemas.py index d079f24187ce..36525a0e9531 100644 --- a/src/prefect/server/api/block_schemas.py +++ b/src/prefect/server/api/block_schemas.py @@ -21,7 +21,7 @@ from prefect.server.models.block_schemas import MissingBlockTypeException from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/block_schemas", tags=["Block schemas"]) +router: PrefectRouter = PrefectRouter(prefix="/block_schemas", tags=["Block schemas"]) @router.post("/", status_code=status.HTTP_201_CREATED) @@ -68,8 +68,8 @@ async def create_block_schema( async def delete_block_schema( block_schema_id: UUID = Path(..., description="The block schema id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), - api_version=Depends(dependencies.provide_request_api_version), -): + api_version: str = Depends(dependencies.provide_request_api_version), +) -> None: """ Delete a block schema by id. """ diff --git a/src/prefect/server/api/block_types.py b/src/prefect/server/api/block_types.py index 50351781dc64..b29a8eeb0144 100644 --- a/src/prefect/server/api/block_types.py +++ b/src/prefect/server/api/block_types.py @@ -10,7 +10,7 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/block_types", tags=["Block types"]) +router: PrefectRouter = PrefectRouter(prefix="/block_types", tags=["Block types"]) @router.post("/", status_code=status.HTTP_201_CREATED) @@ -101,7 +101,7 @@ async def update_block_type( block_type: schemas.actions.BlockTypeUpdate, block_type_id: UUID = Path(..., description="The block type ID", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Update a block type. """ @@ -131,7 +131,7 @@ async def update_block_type( async def delete_block_type( block_type_id: UUID = Path(..., description="The block type ID", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: db_block_type = await models.block_types.read_block_type( session=session, block_type_id=block_type_id @@ -204,7 +204,7 @@ async def read_block_document_by_name_for_block_type( @router.post("/install_system_block_types") async def install_system_block_types( db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: # Don't begin a transaction. _install_protected_system_blocks will manage # the transactions. async with db.session_context(begin_transaction=False) as session: diff --git a/src/prefect/server/api/clients.py b/src/prefect/server/api/clients.py index 1c4d5424e5cc..3a2dd6a1d763 100644 --- a/src/prefect/server/api/clients.py +++ b/src/prefect/server/api/clients.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from urllib.parse import quote from uuid import UUID @@ -19,7 +19,10 @@ from prefect.server.schemas.responses import DeploymentResponse, OrchestrationResult from prefect.types import StrictVariableValue -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) class BaseClient: diff --git a/src/prefect/server/api/collections.py b/src/prefect/server/api/collections.py index 66b47d9d2068..a5c01ebc3f82 100644 --- a/src/prefect/server/api/collections.py +++ b/src/prefect/server/api/collections.py @@ -8,9 +8,11 @@ from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/collections", tags=["Collections"]) +router: PrefectRouter = PrefectRouter(prefix="/collections", tags=["Collections"]) -GLOBAL_COLLECTIONS_VIEW_CACHE: TTLCache = TTLCache(maxsize=200, ttl=60 * 10) +GLOBAL_COLLECTIONS_VIEW_CACHE: TTLCache[str, dict[str, Any]] = TTLCache( + maxsize=200, ttl=60 * 10 +) REGISTRY_VIEWS = ( "https://raw.githubusercontent.com/PrefectHQ/prefect-collection-registry/main/views" @@ -43,7 +45,7 @@ async def read_view_content(view: str) -> Dict[str, Any]: raise -async def get_collection_view(view: str): +async def get_collection_view(view: str) -> dict[str, Any]: try: return GLOBAL_COLLECTIONS_VIEW_CACHE[view] except KeyError: diff --git a/src/prefect/server/api/concurrency_limits.py b/src/prefect/server/api/concurrency_limits.py index 990cac57feeb..231f88b19e6d 100644 --- a/src/prefect/server/api/concurrency_limits.py +++ b/src/prefect/server/api/concurrency_limits.py @@ -17,7 +17,9 @@ from prefect.server.utilities.server import PrefectRouter from prefect.settings import PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS -router = PrefectRouter(prefix="/concurrency_limits", tags=["Concurrency Limits"]) +router: PrefectRouter = PrefectRouter( + prefix="/concurrency_limits", tags=["Concurrency Limits"] +) @router.post("/") @@ -119,7 +121,7 @@ async def reset_concurrency_limit_by_tag( description="Manual override for active concurrency limit slots.", ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: model = await models.concurrency_limits.reset_concurrency_limit_by_tag( session=session, tag=tag, slot_override=slot_override @@ -136,7 +138,7 @@ async def delete_concurrency_limit( ..., description="The concurrency limit id", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: result = await models.concurrency_limits.delete_concurrency_limit( session=session, concurrency_limit_id=concurrency_limit_id @@ -151,7 +153,7 @@ async def delete_concurrency_limit( async def delete_concurrency_limit_by_tag( tag: str = Path(..., description="The tag name"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: result = await models.concurrency_limits.delete_concurrency_limit_by_tag( session=session, tag=tag @@ -263,7 +265,7 @@ async def decrement_concurrency_limits_v1( ..., description="The ID of the task run releasing the slot" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: filtered_limits = ( await concurrency_limits.filter_concurrency_limits_for_orchestration( diff --git a/src/prefect/server/api/concurrency_limits_v2.py b/src/prefect/server/api/concurrency_limits_v2.py index 76d96313c060..75b0f25115c2 100644 --- a/src/prefect/server/api/concurrency_limits_v2.py +++ b/src/prefect/server/api/concurrency_limits_v2.py @@ -11,7 +11,9 @@ from prefect.server.utilities.schemas import PrefectBaseModel from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/v2/concurrency_limits", tags=["Concurrency Limits V2"]) +router: PrefectRouter = PrefectRouter( + prefix="/v2/concurrency_limits", tags=["Concurrency Limits V2"] +) @router.post("/", status_code=status.HTTP_201_CREATED) @@ -85,7 +87,7 @@ async def update_concurrency_limit_v2( ..., description="The ID or name of the concurrency limit", alias="id_or_name" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: if isinstance(id_or_name, str): # TODO: this seems like it shouldn't be necessary try: id_or_name = UUID(id_or_name) @@ -115,7 +117,7 @@ async def delete_concurrency_limit_v2( ..., description="The ID or name of the concurrency limit", alias="id_or_name" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: if isinstance(id_or_name, str): # TODO: this seems like it shouldn't be necessary try: id_or_name = UUID(id_or_name) diff --git a/src/prefect/server/api/csrf_token.py b/src/prefect/server/api/csrf_token.py index 94eaa399da7f..b4554876b9ae 100644 --- a/src/prefect/server/api/csrf_token.py +++ b/src/prefect/server/api/csrf_token.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + from fastapi import Depends, Query, status from starlette.exceptions import HTTPException @@ -7,9 +9,12 @@ from prefect.server.utilities.server import PrefectRouter from prefect.settings import PREFECT_SERVER_CSRF_PROTECTION_ENABLED -logger = get_logger("server.api") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("server.api") -router = PrefectRouter(prefix="/csrf-token") +router: PrefectRouter = PrefectRouter(prefix="/csrf-token") @router.get("") diff --git a/src/prefect/server/api/dependencies.py b/src/prefect/server/api/dependencies.py index 6f9fb797b12e..b20db69cf127 100644 --- a/src/prefect/server/api/dependencies.py +++ b/src/prefect/server/api/dependencies.py @@ -1,11 +1,12 @@ """ Utilities for injecting FastAPI dependencies. """ +from __future__ import annotations import logging import re from base64 import b64decode -from typing import Annotated, Optional +from typing import Annotated, Any, Optional from uuid import UUID from fastapi import Body, Depends, Header, HTTPException, status @@ -16,13 +17,15 @@ from prefect.settings import PREFECT_API_DEFAULT_LIMIT -def provide_request_api_version(x_prefect_api_version: str = Header(None)): +def provide_request_api_version( + x_prefect_api_version: str = Header(None), +) -> Version | None: if not x_prefect_api_version: return # parse version try: - major, minor, patch = [int(v) for v in x_prefect_api_version.split(".")] + _, _, _ = [int(v) for v in x_prefect_api_version.split(".")] except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -46,15 +49,15 @@ class EnforceMinimumAPIVersion: def __init__(self, minimum_api_version: str, logger: logging.Logger): self.minimum_api_version = minimum_api_version versions = [int(v) for v in minimum_api_version.split(".")] - self.api_major = versions[0] - self.api_minor = versions[1] - self.api_patch = versions[2] + self.api_major: int = versions[0] + self.api_minor: int = versions[1] + self.api_patch: int = versions[2] self.logger = logger async def __call__( self, x_prefect_api_version: str = Header(None), - ): + ) -> None: request_version = x_prefect_api_version # if no version header, assume latest and continue @@ -96,7 +99,7 @@ async def _notify_of_outdated_version(self, request_version: str): ) -def LimitBody() -> Depends: +def LimitBody() -> Any: """ A `fastapi.Depends` factory for pulling a `limit: int` parameter from the request body while determining the default from the current settings. @@ -163,7 +166,7 @@ def get_updated_by( return None -def is_ephemeral_request(request: Request): +def is_ephemeral_request(request: Request) -> bool: """ A dependency that returns whether the request is to an ephemeral server. """ diff --git a/src/prefect/server/api/deployments.py b/src/prefect/server/api/deployments.py index b0b89b721f79..96463c09e79a 100644 --- a/src/prefect/server/api/deployments.py +++ b/src/prefect/server/api/deployments.py @@ -37,7 +37,7 @@ validate, ) -router = PrefectRouter(prefix="/deployments", tags=["Deployments"]) +router: PrefectRouter = PrefectRouter(prefix="/deployments", tags=["Deployments"]) def _multiple_schedules_error(deployment_id) -> HTTPException: @@ -181,7 +181,7 @@ async def update_deployment( deployment: schemas.actions.DeploymentUpdate, deployment_id: UUID = Path(..., description="The deployment id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: existing_deployment = await models.deployments.read_deployment( session=session, deployment_id=deployment_id @@ -485,7 +485,7 @@ async def count_deployments( async def delete_deployment( deployment_id: UUID = Path(..., description="The deployment id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a deployment by id. """ @@ -811,7 +811,7 @@ async def update_deployment_schedule( default=..., description="The updated schedule" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: deployment = await models.deployments.read_deployment( session=session, deployment_id=deployment_id @@ -844,7 +844,7 @@ async def delete_deployment_schedule( deployment_id: UUID = Path(..., description="The deployment id", alias="id"), schedule_id: UUID = Path(..., description="The schedule id", alias="schedule_id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: deployment = await models.deployments.read_deployment( session=session, deployment_id=deployment_id diff --git a/src/prefect/server/api/events.py b/src/prefect/server/api/events.py index 87f4f914f484..7e5c87f7e0c3 100644 --- a/src/prefect/server/api/events.py +++ b/src/prefect/server/api/events.py @@ -1,5 +1,5 @@ import base64 -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from fastapi import Response, WebSocket, status from fastapi.exceptions import HTTPException @@ -34,17 +34,20 @@ PREFECT_EVENTS_WEBSOCKET_BACKFILL_PAGE_SIZE, ) -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging +logger: "logging.Logger" = get_logger(__name__) -router = PrefectRouter(prefix="/events", tags=["Events"]) + +router: PrefectRouter = PrefectRouter(prefix="/events", tags=["Events"]) @router.post("", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) async def create_events( events: List[Event], ephemeral_request: bool = Depends(is_ephemeral_request), -): +) -> None: """Record a batch of Events""" if ephemeral_request: await EventsPipeline().process_events(events) diff --git a/src/prefect/server/api/flow_run_notification_policies.py b/src/prefect/server/api/flow_run_notification_policies.py index c0a6da90274b..47b4f871fc2a 100644 --- a/src/prefect/server/api/flow_run_notification_policies.py +++ b/src/prefect/server/api/flow_run_notification_policies.py @@ -13,7 +13,7 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter( +router: PrefectRouter = PrefectRouter( prefix="/flow_run_notification_policies", tags=["Flow Run Notification Policies"] ) @@ -39,7 +39,7 @@ async def update_flow_run_notification_policy( ..., description="The flow run notification policy id", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Updates an existing flow run notification policy. """ @@ -104,7 +104,7 @@ async def delete_flow_run_notification_policy( ..., description="The flow run notification policy id", alias="id" ), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a flow run notification policy by id. """ diff --git a/src/prefect/server/api/flow_run_states.py b/src/prefect/server/api/flow_run_states.py index 72784f0175a1..438e531a9498 100644 --- a/src/prefect/server/api/flow_run_states.py +++ b/src/prefect/server/api/flow_run_states.py @@ -12,7 +12,9 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/flow_run_states", tags=["Flow Run States"]) +router: PrefectRouter = PrefectRouter( + prefix="/flow_run_states", tags=["Flow Run States"] +) @router.get("/{id}") diff --git a/src/prefect/server/api/flow_runs.py b/src/prefect/server/api/flow_runs.py index b3e36cf2a0b2..8ddfa9d2689f 100644 --- a/src/prefect/server/api/flow_runs.py +++ b/src/prefect/server/api/flow_runs.py @@ -5,7 +5,7 @@ import csv import datetime import io -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID import orjson @@ -37,7 +37,10 @@ read_flow_run_graph, ) from prefect.server.orchestration import dependencies as orchestration_dependencies -from prefect.server.orchestration.policies import BaseOrchestrationPolicy +from prefect.server.orchestration.policies import ( + FlowRunOrchestrationPolicy, + TaskRunOrchestrationPolicy, +) from prefect.server.schemas.graph import Graph from prefect.server.schemas.responses import ( FlowRunPaginationResponse, @@ -47,9 +50,12 @@ from prefect.types import DateTime from prefect.utilities import schema_tools -logger = get_logger("server.api") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("server.api") -router = PrefectRouter(prefix="/flow_runs", tags=["Flow Runs"]) +router: PrefectRouter = PrefectRouter(prefix="/flow_runs", tags=["Flow Runs"]) @router.post("/") @@ -101,7 +107,7 @@ async def update_flow_run( flow_run: schemas.actions.FlowRunUpdate, flow_run_id: UUID = Path(..., description="The flow run id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Updates a flow run. """ @@ -349,18 +355,18 @@ async def read_flow_run_graph_v2( async def resume_flow_run( flow_run_id: UUID = Path(..., description="The flow run id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), - run_input: Optional[Dict] = Body(default=None, embed=True), + run_input: Optional[dict[str, Any]] = Body(default=None, embed=True), response: Response = None, - flow_policy: Type[BaseOrchestrationPolicy] = Depends( + flow_policy: type[FlowRunOrchestrationPolicy] = Depends( orchestration_dependencies.provide_flow_policy ), - task_policy: BaseOrchestrationPolicy = Depends( + task_policy: type[TaskRunOrchestrationPolicy] = Depends( orchestration_dependencies.provide_task_policy ), orchestration_parameters: Dict[str, Any] = Depends( orchestration_dependencies.provide_flow_orchestration_parameters ), - api_version=Depends(dependencies.provide_request_api_version), + api_version: str = Depends(dependencies.provide_request_api_version), ) -> OrchestrationResult: """ Resume a paused flow run. @@ -539,7 +545,7 @@ async def read_flow_runs( async def delete_flow_run( flow_run_id: UUID = Path(..., description="The flow run id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a flow run by id. """ @@ -565,14 +571,14 @@ async def set_flow_run_state( ), ), db: PrefectDBInterface = Depends(provide_database_interface), - response: Response = None, - flow_policy: Type[BaseOrchestrationPolicy] = Depends( + flow_policy: type[FlowRunOrchestrationPolicy] = Depends( orchestration_dependencies.provide_flow_policy ), orchestration_parameters: Dict[str, Any] = Depends( orchestration_dependencies.provide_flow_orchestration_parameters ), - api_version=Depends(dependencies.provide_request_api_version), + response: Response = None, + api_version: str = Depends(dependencies.provide_request_api_version), ) -> OrchestrationResult: """Set a flow run state, invoking any orchestration rules.""" @@ -611,7 +617,7 @@ async def create_flow_run_input( value: bytes = Body(..., description="The value of the input"), sender: Optional[str] = Body(None, description="The sender of the input"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Create a key/value input for a flow run. """ @@ -693,7 +699,7 @@ async def delete_flow_run_input( flow_run_id: UUID = Path(..., description="The flow run id", alias="id"), key: str = Path(..., description="The input key", alias="key"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a flow run input """ @@ -848,7 +854,7 @@ async def update_flow_run_labels( flow_run_id: UUID = Path(..., description="The flow run id", alias="id"), labels: Dict[str, Any] = Body(..., description="The labels to update"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Update the labels of a flow run. """ diff --git a/src/prefect/server/api/flows.py b/src/prefect/server/api/flows.py index 4f7b5696c88d..a62adaacab5e 100644 --- a/src/prefect/server/api/flows.py +++ b/src/prefect/server/api/flows.py @@ -16,7 +16,7 @@ from prefect.server.schemas.responses import FlowPaginationResponse from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/flows", tags=["Flows"]) +router: PrefectRouter = PrefectRouter(prefix="/flows", tags=["Flows"]) @router.post("/") @@ -46,7 +46,7 @@ async def update_flow( flow: schemas.actions.FlowUpdate, flow_id: UUID = Path(..., description="The flow id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Updates a flow. """ @@ -150,7 +150,7 @@ async def read_flows( async def delete_flow( flow_id: UUID = Path(..., description="The flow id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a flow by id. """ diff --git a/src/prefect/server/api/logs.py b/src/prefect/server/api/logs.py index 18b723c1f91f..b473177fc3d5 100644 --- a/src/prefect/server/api/logs.py +++ b/src/prefect/server/api/logs.py @@ -12,14 +12,14 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/logs", tags=["Logs"]) +router: PrefectRouter = PrefectRouter(prefix="/logs", tags=["Logs"]) @router.post("/", status_code=status.HTTP_201_CREATED) async def create_logs( logs: List[schemas.actions.LogCreate], db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """Create new logs from the provided schema.""" for batch in models.logs.split_logs_into_batches(logs): async with db.session_context(begin_transaction=True) as session: diff --git a/src/prefect/server/api/root.py b/src/prefect/server/api/root.py index 8a52153a4b8b..27420a12d563 100644 --- a/src/prefect/server/api/root.py +++ b/src/prefect/server/api/root.py @@ -8,11 +8,11 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="", tags=["Root"]) +router: PrefectRouter = PrefectRouter(prefix="", tags=["Root"]) @router.get("/hello") -async def hello(): +async def hello() -> str: """Say hello!""" return "👋" @@ -20,7 +20,7 @@ async def hello(): @router.get("/ready") async def perform_readiness_check( db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> JSONResponse: is_db_connectable = await db.is_db_connectable() if is_db_connectable: diff --git a/src/prefect/server/api/run_history.py b/src/prefect/server/api/run_history.py index b707b06ed4fb..85e7c3577861 100644 --- a/src/prefect/server/api/run_history.py +++ b/src/prefect/server/api/run_history.py @@ -4,7 +4,7 @@ import datetime import json -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import pydantic import sqlalchemy as sa @@ -16,7 +16,10 @@ from prefect.server.database import PrefectDBInterface, db_injector from prefect.types import DateTime -logger = get_logger("server.api") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("server.api") @db_injector diff --git a/src/prefect/server/api/saved_searches.py b/src/prefect/server/api/saved_searches.py index fce131484590..786c8df03533 100644 --- a/src/prefect/server/api/saved_searches.py +++ b/src/prefect/server/api/saved_searches.py @@ -14,7 +14,7 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/saved_searches", tags=["SavedSearches"]) +router: PrefectRouter = PrefectRouter(prefix="/saved_searches", tags=["SavedSearches"]) @router.put("/") @@ -85,7 +85,7 @@ async def read_saved_searches( async def delete_saved_search( saved_search_id: UUID = Path(..., description="The saved search id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a saved search by id. """ diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index 908ae5c3afc8..55f6b26f56e9 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -21,7 +21,7 @@ from contextlib import asynccontextmanager from functools import partial, wraps from hashlib import sha256 -from typing import Any, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional import anyio import asyncpg @@ -66,17 +66,20 @@ ) from prefect.utilities.hashing import hash_objects +if TYPE_CHECKING: + import logging + TITLE = "Prefect Server" API_TITLE = "Prefect Prefect REST API" UI_TITLE = "Prefect Prefect REST API UI" -API_VERSION = prefect.__version__ +API_VERSION: str = prefect.__version__ # migrations should run only once per app start; the ephemeral API can potentially # create multiple apps in a single process LIFESPAN_RAN_FOR_APP: set[Any] = set() -logger = get_logger("server") +logger: "logging.Logger" = get_logger("server") -enforce_minimum_version = EnforceMinimumAPIVersion( +enforce_minimum_version: EnforceMinimumAPIVersion = EnforceMinimumAPIVersion( # this should be <= SERVER_API_VERSION; clients that send # a version header under this value will be rejected minimum_api_version="0.8.0", @@ -154,7 +157,9 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None: await self.app(scope, receive, send) -async def validation_exception_handler(request: Request, exc: RequestValidationError): +async def validation_exception_handler( + request: Request, exc: RequestValidationError +) -> JSONResponse: """Provide a detailed message for request validation errors.""" return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -168,7 +173,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE ) -async def integrity_exception_handler(request: Request, exc: Exception): +async def integrity_exception_handler(request: Request, exc: Exception) -> JSONResponse: """Capture database integrity errors.""" logger.error("Encountered exception in request:", exc_info=True) return JSONResponse( @@ -183,7 +188,7 @@ async def integrity_exception_handler(request: Request, exc: Exception): ) -def is_client_retryable_exception(exc: Exception): +def is_client_retryable_exception(exc: Exception) -> bool: if isinstance(exc, sqlalchemy.exc.OperationalError) and isinstance( exc.orig, sqlite3.OperationalError ): @@ -217,7 +222,7 @@ def replace_placeholder_string_in_files( placeholder: str, replacement: str, allowed_extensions: list[str] | None = None, -): +) -> None: """ Recursively loops through all files in the given directory and replaces a placeholder string. @@ -253,7 +258,9 @@ def copy_directory(directory: str, path: str) -> None: shutil.copy2(source, destination) -async def custom_internal_exception_handler(request: Request, exc: Exception): +async def custom_internal_exception_handler( + request: Request, exc: Exception +) -> JSONResponse: """ Log a detailed exception for internal server errors before returning. @@ -278,7 +285,7 @@ async def custom_internal_exception_handler(request: Request, exc: Exception): async def prefect_object_not_found_exception_handler( request: Request, exc: ObjectNotFoundError -): +) -> JSONResponse: """Return 404 status code on object not found exceptions.""" return JSONResponse( content={"exception_message": str(exc)}, status_code=status.HTTP_404_NOT_FOUND @@ -353,6 +360,9 @@ async def server_version() -> str: # type: ignore[reportUnusedFunction] async def token_validation(request: Request, call_next: Any): # type: ignore[reportUnusedFunction] header_token = request.headers.get("Authorization") + # used for probes in k8s and such + if request.url.path in ["/api/health", "/api/ready"]: + return await call_next(request) try: if header_token is None: return JSONResponse( @@ -787,7 +797,7 @@ def openapi(): return app -subprocess_server_logger = get_logger() +subprocess_server_logger: "logging.Logger" = get_logger() class SubprocessASGIServer: @@ -808,12 +818,11 @@ def __init__(self, port: Optional[int] = None): # This ensures initialization happens only once if not hasattr(self, "_initialized"): self.port: Optional[int] = port - self.server_process = None - self.server = None - self.running = False + self.server_process: subprocess.Popen[Any] | None = None + self.running: bool = False self._initialized = True - def find_available_port(self): + def find_available_port(self) -> int: max_attempts = 10 for _ in range(max_attempts): port = random.choice(self._port_range) @@ -823,7 +832,7 @@ def find_available_port(self): raise RuntimeError("Unable to find an available port after multiple attempts") @staticmethod - def is_port_available(port: int): + def is_port_available(port: int) -> bool: with contextlib.closing( socket.socket(socket.AF_INET, socket.SOCK_STREAM) ) as sock: @@ -841,7 +850,7 @@ def address(self) -> str: def api_url(self) -> str: return f"{self.address}/api" - def start(self, timeout: Optional[int] = None): + def start(self, timeout: Optional[int] = None) -> None: """ Start the server in a separate process. Safe to call multiple times; only starts the server once. @@ -935,7 +944,7 @@ def _run_uvicorn_command(self) -> subprocess.Popen[Any]: }, ) - def stop(self): + def stop(self) -> None: if self.server_process: subprocess_server_logger.info( f"Stopping temporary server on {self.address}" diff --git a/src/prefect/server/api/task_run_states.py b/src/prefect/server/api/task_run_states.py index ef68e8b4d7a0..9722a1d3d93b 100644 --- a/src/prefect/server/api/task_run_states.py +++ b/src/prefect/server/api/task_run_states.py @@ -12,7 +12,9 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/task_run_states", tags=["Task Run States"]) +router: PrefectRouter = PrefectRouter( + prefix="/task_run_states", tags=["Task Run States"] +) @router.get("/{id}") diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py index 69d5e14adbd6..316eb2551e5d 100644 --- a/src/prefect/server/api/task_runs.py +++ b/src/prefect/server/api/task_runs.py @@ -4,7 +4,7 @@ import asyncio import datetime -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID import pendulum @@ -27,17 +27,19 @@ from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.orchestration import dependencies as orchestration_dependencies from prefect.server.orchestration.core_policy import CoreTaskPolicy -from prefect.server.orchestration.policies import BaseOrchestrationPolicy +from prefect.server.orchestration.policies import TaskRunOrchestrationPolicy from prefect.server.schemas.responses import OrchestrationResult from prefect.server.task_queue import MultiQueue, TaskQueue from prefect.server.utilities import subscriptions from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime -logger = get_logger("server.api") +if TYPE_CHECKING: + import logging +logger: "logging.Logger" = get_logger("server.api") -router = PrefectRouter(prefix="/task_runs", tags=["Task Runs"]) +router: PrefectRouter = PrefectRouter(prefix="/task_runs", tags=["Task Runs"]) @router.post("/") @@ -87,7 +89,7 @@ async def update_task_run( task_run: schemas.actions.TaskRunUpdate, task_run_id: UUID = Path(..., description="The task run id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Updates a task run. """ @@ -214,7 +216,7 @@ async def read_task_runs( async def delete_task_run( task_run_id: UUID = Path(..., description="The task run id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a task run by id. """ @@ -239,7 +241,7 @@ async def set_task_run_state( ), db: PrefectDBInterface = Depends(provide_database_interface), response: Response = None, - task_policy: BaseOrchestrationPolicy = Depends( + task_policy: TaskRunOrchestrationPolicy = Depends( orchestration_dependencies.provide_task_policy ), orchestration_parameters: Dict[str, Any] = Depends( @@ -275,7 +277,7 @@ async def set_task_run_state( @router.websocket("/subscriptions/scheduled") -async def scheduled_task_subscription(websocket: WebSocket): +async def scheduled_task_subscription(websocket: WebSocket) -> None: websocket = await subscriptions.accept_prefect_socket(websocket) if not websocket: return diff --git a/src/prefect/server/api/task_workers.py b/src/prefect/server/api/task_workers.py index b3ebc3edb6ff..cc8d30a8cea9 100644 --- a/src/prefect/server/api/task_workers.py +++ b/src/prefect/server/api/task_workers.py @@ -7,7 +7,7 @@ from prefect.server.models.task_workers import TaskWorkerResponse from prefect.server.utilities.server import PrefectRouter -router = PrefectRouter(prefix="/task_workers", tags=["Task Workers"]) +router: PrefectRouter = PrefectRouter(prefix="/task_workers", tags=["Task Workers"]) class TaskWorkerFilter(BaseModel): diff --git a/src/prefect/server/api/templates.py b/src/prefect/server/api/templates.py index 29b954aac56e..55c927eb962f 100644 --- a/src/prefect/server/api/templates.py +++ b/src/prefect/server/api/templates.py @@ -8,7 +8,7 @@ validate_user_template, ) -router = PrefectRouter(prefix="/templates", tags=["Automations"]) +router: PrefectRouter = PrefectRouter(prefix="/templates", tags=["Automations"]) @router.post( diff --git a/src/prefect/server/api/ui/flow_runs.py b/src/prefect/server/api/ui/flow_runs.py index a67999f92919..f443c4dcc8fc 100644 --- a/src/prefect/server/api/ui/flow_runs.py +++ b/src/prefect/server/api/ui/flow_runs.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import datetime -from typing import List +from typing import TYPE_CHECKING, List from uuid import UUID import sqlalchemy as sa @@ -14,9 +16,12 @@ from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime -logger = get_logger("server.api.ui.flow_runs") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("server.api.ui.flow_runs") -router = PrefectRouter(prefix="/ui/flow_runs", tags=["Flow Runs", "UI"]) +router: PrefectRouter = PrefectRouter(prefix="/ui/flow_runs", tags=["Flow Runs", "UI"]) class SimpleFlowRun(PrefectBaseModel): diff --git a/src/prefect/server/api/ui/flows.py b/src/prefect/server/api/ui/flows.py index 8c0b3375d480..e78dad89c674 100644 --- a/src/prefect/server/api/ui/flows.py +++ b/src/prefect/server/api/ui/flows.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from datetime import datetime -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from uuid import UUID import pendulum @@ -15,9 +17,12 @@ from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime -logger = get_logger() +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger() -router = PrefectRouter(prefix="/ui/flows", tags=["Flows", "UI"]) +router: PrefectRouter = PrefectRouter(prefix="/ui/flows", tags=["Flows", "UI"]) class SimpleNextFlowRun(PrefectBaseModel): @@ -32,7 +37,9 @@ class SimpleNextFlowRun(PrefectBaseModel): @field_validator("next_scheduled_start_time", mode="before") @classmethod - def validate_next_scheduled_start_time(cls, v): + def validate_next_scheduled_start_time( + cls, v: pendulum.DateTime | datetime + ) -> pendulum.DateTime: if isinstance(v, datetime): return pendulum.instance(v) return v diff --git a/src/prefect/server/api/ui/schemas.py b/src/prefect/server/api/ui/schemas.py index 768b9c7231d2..91a0bbeb472a 100644 --- a/src/prefect/server/api/ui/schemas.py +++ b/src/prefect/server/api/ui/schemas.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict from fastapi import Body, Depends, HTTPException, status @@ -14,9 +14,12 @@ validate, ) -router = APIRouter(prefix="/ui/schemas", tags=["UI", "Schemas"]) +if TYPE_CHECKING: + import logging -logger = get_logger("server.api.ui.schemas") +router: APIRouter = APIRouter(prefix="/ui/schemas", tags=["UI", "Schemas"]) + +logger: "logging.Logger" = get_logger("server.api.ui.schemas") @router.post("/validate") @@ -24,7 +27,7 @@ async def validate_obj( json_schema: Dict[str, Any] = Body(..., embed=True, alias="schema"), values: Dict[str, Any] = Body(..., embed=True), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> dict[str, Any]: schema = preprocess_schema(json_schema) try: diff --git a/src/prefect/server/api/ui/task_runs.py b/src/prefect/server/api/ui/task_runs.py index 938d2a27d57e..6b2de8c9ba5c 100644 --- a/src/prefect/server/api/ui/task_runs.py +++ b/src/prefect/server/api/ui/task_runs.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional, cast +from typing import TYPE_CHECKING, List, Optional, cast import pendulum import sqlalchemy as sa @@ -14,9 +14,12 @@ from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime -logger = get_logger("orion.api.ui.task_runs") +if TYPE_CHECKING: + import logging -router = PrefectRouter(prefix="/ui/task_runs", tags=["Task Runs", "UI"]) +logger: "logging.Logger" = get_logger("server.api.ui.task_runs") + +router: PrefectRouter = PrefectRouter(prefix="/ui/task_runs", tags=["Task Runs", "UI"]) FAILED_STATES = [schemas.states.StateType.CRASHED, schemas.states.StateType.FAILED] @@ -28,7 +31,7 @@ class TaskRunCount(PrefectBaseModel): failed: int = Field(default=..., description="The number of failed task runs.") @model_serializer - def ser_model(self) -> dict: + def ser_model(self) -> dict[str, int]: return { "completed": int(self.completed), "failed": int(self.failed), diff --git a/src/prefect/server/api/validation.py b/src/prefect/server/api/validation.py index ebc085a32e25..3c1905ca571d 100644 --- a/src/prefect/server/api/validation.py +++ b/src/prefect/server/api/validation.py @@ -35,7 +35,7 @@ allow None values, the Pydantic model will fail to validate at runtime. """ -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from uuid import UUID import pydantic @@ -50,7 +50,10 @@ from prefect.server.schemas.core import WorkPool from prefect.utilities.schema_tools import ValidationError, is_valid_schema, validate -logger = get_logger("server.api.validation") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("server.api.validation") DeploymentAction = Union[ schemas.actions.DeploymentCreate, schemas.actions.DeploymentUpdate diff --git a/src/prefect/server/api/variables.py b/src/prefect/server/api/variables.py index b05cb8c648e2..e7e5785d3615 100644 --- a/src/prefect/server/api/variables.py +++ b/src/prefect/server/api/variables.py @@ -11,12 +11,18 @@ from prefect.server import models from prefect.server.api.dependencies import LimitBody -from prefect.server.database import PrefectDBInterface, provide_database_interface +from prefect.server.database import ( + PrefectDBInterface, + orm_models, + provide_database_interface, +) from prefect.server.schemas import actions, core, filters, sorting from prefect.server.utilities.server import PrefectRouter -async def get_variable_or_404(session: AsyncSession, variable_id: UUID): +async def get_variable_or_404( + session: AsyncSession, variable_id: UUID +) -> orm_models.Variable: """Returns a variable or raises 404 HTTPException if it does not exist""" variable = await models.variables.read_variable( @@ -28,7 +34,9 @@ async def get_variable_or_404(session: AsyncSession, variable_id: UUID): return variable -async def get_variable_by_name_or_404(session: AsyncSession, name: str): +async def get_variable_by_name_or_404( + session: AsyncSession, name: str +) -> orm_models.Variable: """Returns a variable or raises 404 HTTPException if it does not exist""" variable = await models.variables.read_variable_by_name(session=session, name=name) @@ -38,7 +46,7 @@ async def get_variable_by_name_or_404(session: AsyncSession, name: str): return variable -router = PrefectRouter( +router: PrefectRouter = PrefectRouter( prefix="/variables", tags=["Variables"], ) @@ -120,7 +128,7 @@ async def update_variable( variable: actions.VariableUpdate, variable_id: UUID = Path(..., alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: updated = await models.variables.update_variable( session=session, @@ -136,7 +144,7 @@ async def update_variable_by_name( variable: actions.VariableUpdate, name: str = Path(..., alias="name"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: updated = await models.variables.update_variable_by_name( session=session, @@ -151,7 +159,7 @@ async def update_variable_by_name( async def delete_variable( variable_id: UUID = Path(..., alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: deleted = await models.variables.delete_variable( session=session, variable_id=variable_id @@ -164,7 +172,7 @@ async def delete_variable( async def delete_variable_by_name( name: str = Path(...), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: deleted = await models.variables.delete_variable_by_name( session=session, name=name diff --git a/src/prefect/server/api/work_queues.py b/src/prefect/server/api/work_queues.py index ad9e17df1645..1267fe7ae87a 100644 --- a/src/prefect/server/api/work_queues.py +++ b/src/prefect/server/api/work_queues.py @@ -33,7 +33,7 @@ from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime -router = PrefectRouter(prefix="/work_queues", tags=["Work Queues"]) +router: PrefectRouter = PrefectRouter(prefix="/work_queues", tags=["Work Queues"]) @router.post("/", status_code=status.HTTP_201_CREATED) @@ -69,7 +69,7 @@ async def update_work_queue( work_queue: schemas.actions.WorkQueueUpdate, work_queue_id: UUID = Path(..., description="The work queue id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Updates an existing work queue. """ @@ -229,7 +229,7 @@ async def read_work_queues( async def delete_work_queue( work_queue_id: UUID = Path(..., description="The work queue id", alias="id"), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a work queue by id. """ diff --git a/src/prefect/server/api/workers.py b/src/prefect/server/api/workers.py index 0f0d8546dd19..7d46d16bec42 100644 --- a/src/prefect/server/api/workers.py +++ b/src/prefect/server/api/workers.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from prefect.server.database.orm_models import ORMWorkQueue -router = PrefectRouter( +router: PrefectRouter = PrefectRouter( prefix="/work_pools", tags=["Work Pools"], ) @@ -257,7 +257,7 @@ async def update_work_pool( work_pool_name: str = Path(..., description="The work pool name", alias="name"), worker_lookups: WorkerLookups = Depends(WorkerLookups), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Update a work pool """ @@ -292,7 +292,7 @@ async def delete_work_pool( work_pool_name: str = Path(..., description="The work pool name", alias="name"), worker_lookups: WorkerLookups = Depends(WorkerLookups), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a work pool """ @@ -505,7 +505,7 @@ async def update_work_queue( ), worker_lookups: WorkerLookups = Depends(WorkerLookups), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Update a work pool queue """ @@ -535,7 +535,7 @@ async def delete_work_queue( ), worker_lookups: WorkerLookups = Depends(WorkerLookups), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a work pool queue """ @@ -573,7 +573,7 @@ async def worker_heartbeat( ), worker_lookups: WorkerLookups = Depends(WorkerLookups), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: async with db.session_context(begin_transaction=True) as session: work_pool = await models.workers.read_work_pool_by_name( session=session, @@ -638,7 +638,7 @@ async def delete_worker( ), worker_lookups: WorkerLookups = Depends(WorkerLookups), db: PrefectDBInterface = Depends(provide_database_interface), -): +) -> None: """ Delete a work pool's worker """ diff --git a/src/prefect/server/models/block_registration.py b/src/prefect/server/models/block_registration.py index f537f75966b7..76cdb66b897d 100644 --- a/src/prefect/server/models/block_registration.py +++ b/src/prefect/server/models/block_registration.py @@ -13,10 +13,12 @@ from prefect.server import models, schemas if TYPE_CHECKING: + import logging + from prefect.client.schemas import BlockSchema as ClientBlockSchema from prefect.client.schemas import BlockType as ClientBlockType -logger = get_logger("server") +logger: "logging.Logger" = get_logger("server") COLLECTIONS_BLOCKS_DATA_PATH = ( Path(__file__).parent.parent / "collection_blocks_data.json" diff --git a/src/prefect/server/models/block_schemas.py b/src/prefect/server/models/block_schemas.py index 264dc4913fc8..dee1cb34f4c9 100644 --- a/src/prefect/server/models/block_schemas.py +++ b/src/prefect/server/models/block_schemas.py @@ -40,7 +40,7 @@ async def create_block_schema( "ClientBlockSchema", ], override: bool = False, - definitions: Optional[Dict] = None, + definitions: Optional[dict[str, Any]] = None, ) -> Union[BlockSchema, orm_models.BlockSchema]: """ Create a new block schema. diff --git a/src/prefect/server/models/flow_runs.py b/src/prefect/server/models/flow_runs.py index fca96360dbfa..d5543d8d98d1 100644 --- a/src/prefect/server/models/flow_runs.py +++ b/src/prefect/server/models/flow_runs.py @@ -7,6 +7,7 @@ import datetime from itertools import chain from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -49,13 +50,13 @@ ) from prefect.types import KeyValueLabels -logger = get_logger("flow_runs") +if TYPE_CHECKING: + import logging +logger: "logging.Logger" = get_logger("flow_runs") -logger = get_logger("flow_runs") - -T = TypeVar("T", bound=tuple) +T = TypeVar("T", bound=tuple[Any, ...]) @db_injector @@ -281,7 +282,7 @@ async def _apply_flow_run_filters( async def read_flow_runs( db: PrefectDBInterface, session: AsyncSession, - columns: Optional[List] = None, + columns: Optional[list[str]] = None, flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, task_run_filter: Optional[schemas.filters.TaskRunFilter] = None, @@ -344,7 +345,7 @@ async def read_flow_runs( async def cleanup_flow_run_concurrency_slots( session: AsyncSession, flow_run: orm_models.FlowRun, -): +) -> None: """ Cleanup flow run related resources, such as releasing concurrency slots. All operations should be idempotent and safe to call multiple times. diff --git a/src/prefect/server/models/logs.py b/src/prefect/server/models/logs.py index ce5eac18eace..5dda835c9de9 100644 --- a/src/prefect/server/models/logs.py +++ b/src/prefect/server/models/logs.py @@ -3,7 +3,7 @@ Intended for internal use by the Prefect REST API. """ -from typing import Generator, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Generator, List, Optional, Sequence, Tuple from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -23,7 +23,10 @@ # ...so we can only INSERT batches of a certain size at a time LOG_BATCH_SIZE = MAXIMUM_QUERY_PARAMETERS // NUMBER_OF_LOG_FIELDS -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) def split_logs_into_batches( diff --git a/src/prefect/server/models/task_runs.py b/src/prefect/server/models/task_runs.py index c821b6c4b43b..ae7f54b7889f 100644 --- a/src/prefect/server/models/task_runs.py +++ b/src/prefect/server/models/task_runs.py @@ -4,7 +4,17 @@ """ import contextlib -from typing import Any, Dict, Optional, Sequence, Type, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, +) from uuid import UUID import pendulum @@ -29,9 +39,12 @@ from prefect.server.orchestration.rules import TaskOrchestrationContext from prefect.server.schemas.responses import OrchestrationResult -T = TypeVar("T", bound=tuple) +if TYPE_CHECKING: + import logging + +T = TypeVar("T", bound=tuple[Any, ...]) -logger = get_logger("server") +logger: "logging.Logger" = get_logger("server") @db_injector diff --git a/src/prefect/server/models/task_workers.py b/src/prefect/server/models/task_workers.py index 899eba98b4d9..4a285feee54f 100644 --- a/src/prefect/server/models/task_workers.py +++ b/src/prefect/server/models/task_workers.py @@ -77,7 +77,7 @@ def reset(self) -> None: # Global instance of the task worker tracker -task_worker_tracker = InMemoryTaskWorkerTracker() +task_worker_tracker: InMemoryTaskWorkerTracker = InMemoryTaskWorkerTracker() # Main utilities to be used in the API layer diff --git a/src/prefect/server/schemas/actions.py b/src/prefect/server/schemas/actions.py index 30746fc529f9..1fd1c20b5ec7 100644 --- a/src/prefect/server/schemas/actions.py +++ b/src/prefect/server/schemas/actions.py @@ -2,6 +2,8 @@ Reduced schemas for accepting API actions. """ +from __future__ import annotations + import json from copy import deepcopy from typing import Any, ClassVar, Dict, List, Optional, Union @@ -44,24 +46,24 @@ from prefect.utilities.templating import find_placeholders -def validate_block_type_slug(value): +def validate_block_type_slug(value: str) -> str: raise_on_name_alphanumeric_dashes_only(value, field_name="Block type slug") return value -def validate_block_document_name(value): +def validate_block_document_name(value: str | None) -> str | None: if value is not None: raise_on_name_alphanumeric_dashes_only(value, field_name="Block document name") return value -def validate_artifact_key(value): +def validate_artifact_key(value: str | None) -> str | None: if value is not None: raise_on_name_alphanumeric_dashes_only(value, field_name="Artifact key") return value -def validate_variable_name(value): +def validate_variable_name(value: str) -> str: raise_on_name_alphanumeric_underscores_only(value, field_name="Variable name") return value @@ -112,7 +114,9 @@ class DeploymentScheduleCreate(ActionBaseModel): @field_validator("max_scheduled_runs") @classmethod - def validate_max_scheduled_runs(cls, v): + def validate_max_scheduled_runs( + cls, v: PositiveInteger | None + ) -> PositiveInteger | None: return validate_schedule_max_scheduled_runs( v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value() ) @@ -133,7 +137,9 @@ class DeploymentScheduleUpdate(ActionBaseModel): @field_validator("max_scheduled_runs") @classmethod - def validate_max_scheduled_runs(cls, v): + def validate_max_scheduled_runs( + cls, v: PositiveInteger | None + ) -> PositiveInteger | None: return validate_schedule_max_scheduled_runs( v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value() ) @@ -187,7 +193,7 @@ class DeploymentCreate(ActionBaseModel): description="A dictionary of key-value labels. Values can be strings, numbers, or booleans.", examples=[{"key": "value1", "key2": 42}], ) - pull_steps: Optional[List[dict]] = Field(None) + pull_steps: Optional[List[dict[str, Any]]] = Field(None) work_queue_name: Optional[str] = Field(None) work_pool_name: Optional[str] = Field( @@ -206,7 +212,7 @@ class DeploymentCreate(ActionBaseModel): description="Overrides for the flow's infrastructure configuration.", ) - def check_valid_configuration(self, base_job_template: dict): + def check_valid_configuration(self, base_job_template: dict[str, Any]) -> None: """ Check that the combination of base_job_template defaults and job_variables conforms to the specified schema. @@ -232,11 +238,13 @@ def check_valid_configuration(self, base_job_template: dict): @model_validator(mode="before") @classmethod - def remove_old_fields(cls, values): + def remove_old_fields(cls, values: dict[str, Any]) -> dict[str, Any]: return remove_old_deployment_fields(values) @model_validator(mode="before") - def _validate_parameters_conform_to_schema(cls, values): + def _validate_parameters_conform_to_schema( + cls, values: dict[str, Any] + ) -> dict[str, Any]: values["parameters"] = validate_parameters_conform_to_schema( values.get("parameters", {}), values ) @@ -251,7 +259,7 @@ class DeploymentUpdate(ActionBaseModel): @model_validator(mode="before") @classmethod - def remove_old_fields(cls, values): + def remove_old_fields(cls, values: dict[str, Any]) -> dict[str, Any]: return remove_old_deployment_fields(values) version: Optional[str] = Field(None) @@ -300,7 +308,7 @@ def remove_old_fields(cls, values): ) model_config: ClassVar[ConfigDict] = ConfigDict(populate_by_name=True) - def check_valid_configuration(self, base_job_template: dict): + def check_valid_configuration(self, base_job_template: dict[str, Any]) -> None: """ Check that the combination of base_job_template defaults and job_variables conforms to the schema specified in the base_job_template. @@ -315,7 +323,7 @@ def check_valid_configuration(self, base_job_template: dict): variables_schema = deepcopy(base_job_template.get("variables")) - if variables_schema is not None: + if variables_schema is not None and self.job_variables is not None: errors = validate( self.job_variables, variables_schema, @@ -343,7 +351,7 @@ class FlowRunUpdate(ActionBaseModel): @field_validator("name", mode="before") @classmethod - def set_name(cls, name): + def set_name(cls, name: str) -> str: return get_or_create_run_name(name) @@ -459,12 +467,12 @@ class TaskRunCreate(ActionBaseModel): @field_validator("name", mode="before") @classmethod - def set_name(cls, name): + def set_name(cls, name: str) -> str: return get_or_create_run_name(name) @field_validator("cache_key") @classmethod - def validate_cache_key(cls, cache_key): + def validate_cache_key(cls, cache_key: str | None) -> str | None: return validate_cache_key_length(cache_key) @@ -477,7 +485,7 @@ class TaskRunUpdate(ActionBaseModel): @field_validator("name", mode="before") @classmethod - def set_name(cls, name): + def set_name(cls, name: str) -> str: return get_or_create_run_name(name) @@ -544,7 +552,7 @@ class FlowRunCreate(ActionBaseModel): @field_validator("name", mode="before") @classmethod - def set_name(cls, name): + def set_name(cls, name: str) -> str: return get_or_create_run_name(name) @@ -683,7 +691,7 @@ class BlockTypeUpdate(ActionBaseModel): code_example: Optional[str] = Field(None) @classmethod - def updatable_fields(cls) -> set: + def updatable_fields(cls) -> set[str]: return get_class_fields_only(cls) @@ -780,7 +788,7 @@ class LogCreate(ActionBaseModel): task_run_id: Optional[UUID] = Field(None) -def validate_base_job_template(v): +def validate_base_job_template(v: dict[str, Any]) -> dict[str, Any]: if v == dict(): return v @@ -791,7 +799,7 @@ def validate_base_job_template(v): "The `base_job_template` must contain both a `job_configuration` key" " and a `variables` key." ) - template_variables = set() + template_variables: set[str] = set() for template in job_config.values(): # find any variables inside of double curly braces, minus any whitespace # e.g. "{{ var1 }}.{{var2}}" -> ["var1", "var2"] @@ -927,7 +935,7 @@ class FlowRunNotificationPolicyCreate(ActionBaseModel): @field_validator("message_template") @classmethod - def validate_message_template_variables(cls, v): + def validate_message_template_variables(cls, v: str | None) -> str | None: return validate_message_template_variables(v) @@ -942,7 +950,7 @@ class FlowRunNotificationPolicyUpdate(ActionBaseModel): @field_validator("message_template") @classmethod - def validate_message_template_variables(cls, v): + def validate_message_template_variables(cls, v: str | None) -> str | None: return validate_message_template_variables(v) @@ -984,8 +992,8 @@ class ArtifactCreate(ActionBaseModel): ) @classmethod - def from_result(cls, data: Any): - artifact_info = dict() + def from_result(cls, data: Any | dict[str, Any]) -> "ArtifactCreate": + artifact_info: dict[str, Any] = dict() if isinstance(data, dict): artifact_key = data.pop("artifact_key", None) if artifact_key: diff --git a/src/prefect/server/schemas/core.py b/src/prefect/server/schemas/core.py index 7b245446f335..029455c668d8 100644 --- a/src/prefect/server/schemas/core.py +++ b/src/prefect/server/schemas/core.py @@ -559,7 +559,7 @@ class DeploymentSchedule(ORMBaseModel): @field_validator("max_scheduled_runs") @classmethod - def validate_max_scheduled_runs(cls, v): + def validate_max_scheduled_runs(cls, v: int) -> int: return validate_schedule_max_scheduled_runs( v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value() ) @@ -1069,7 +1069,7 @@ class FlowRunNotificationPolicy(ORMBaseModel): @field_validator("message_template") @classmethod - def validate_message_template_variables(cls, v): + def validate_message_template_variables(cls, v: str) -> str: return validate_message_template_variables(v) @@ -1187,7 +1187,7 @@ class Artifact(ORMBaseModel): " the artifact type." ), ) - metadata_: Optional[Dict[str, str]] = Field( + metadata_: Optional[dict[str, str]] = Field( default=None, description=( "User-defined artifact metadata. Content must be string key and value" @@ -1202,8 +1202,8 @@ class Artifact(ORMBaseModel): ) @classmethod - def from_result(cls, data: Any): - artifact_info = dict() + def from_result(cls, data: Any | dict[str, Any]) -> "Artifact": + artifact_info: dict[str, Any] = dict() if isinstance(data, dict): artifact_key = data.pop("artifact_key", None) if artifact_key: @@ -1221,7 +1221,7 @@ def from_result(cls, data: Any): @field_validator("metadata_") @classmethod - def validate_metadata_length(cls, v): + def validate_metadata_length(cls, v: dict[str, str]) -> dict[str, str]: return validate_max_metadata_length(v) @@ -1289,7 +1289,7 @@ class FlowRunInput(ORMBaseModel): @field_validator("key", check_fields=False) @classmethod - def validate_name_characters(cls, v): + def validate_name_characters(cls, v: str) -> str: raise_on_name_alphanumeric_dashes_only(v) return v diff --git a/src/prefect/server/schemas/filters.py b/src/prefect/server/schemas/filters.py index 2b1c63452b08..445dc31d495b 100644 --- a/src/prefect/server/schemas/filters.py +++ b/src/prefect/server/schemas/filters.py @@ -676,8 +676,8 @@ class FlowRunFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `FlowRun.idempotency_key`" ) - def only_filters_on_id(self): - return ( + def only_filters_on_id(self) -> bool: + return bool( self.id is not None and (self.id.any_ and not self.id.not_any_) and self.name is None diff --git a/src/prefect/server/schemas/responses.py b/src/prefect/server/schemas/responses.py index 4df1ab929fed..80b74f2e1b76 100644 --- a/src/prefect/server/schemas/responses.py +++ b/src/prefect/server/schemas/responses.py @@ -410,7 +410,7 @@ class DeploymentResponse(ORMBaseModel): " storage or an absolute path." ), ) - pull_steps: Optional[List[dict]] = Field( + pull_steps: Optional[list[dict[str, Any]]] = Field( default=None, description="Pull steps for cloning and running this deployment." ) entrypoint: Optional[str] = Field( diff --git a/src/prefect/server/schemas/schedules.py b/src/prefect/server/schemas/schedules.py index 0466643f784f..438c07d8b6db 100644 --- a/src/prefect/server/schemas/schedules.py +++ b/src/prefect/server/schemas/schedules.py @@ -1,6 +1,7 @@ """ Schedule schemas """ +from __future__ import annotations import datetime from typing import Annotated, Any, ClassVar, Generator, List, Optional, Tuple, Union @@ -225,7 +226,7 @@ def validate_timezone(self): @field_validator("cron") @classmethod - def valid_cron_string(cls, v): + def valid_cron_string(cls, v: str) -> str: return validate_cron_string(v) async def get_dates( @@ -368,11 +369,13 @@ class RRuleSchedule(PrefectBaseModel): @field_validator("rrule") @classmethod - def validate_rrule_str(cls, v): + def validate_rrule_str(cls, v: str) -> str: return validate_rrule_string(v) @classmethod - def from_rrule(cls, rrule: dateutil.rrule.rrule): + def from_rrule( + cls, rrule: dateutil.rrule.rrule | dateutil.rrule.rruleset + ) -> "RRuleSchedule": if isinstance(rrule, dateutil.rrule.rrule): if rrule._dtstart.tzinfo is not None: timezone = rrule._dtstart.tzinfo.name diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py index 2ac202c70d4f..d0ac008f04f0 100644 --- a/src/prefect/server/schemas/states.py +++ b/src/prefect/server/schemas/states.py @@ -63,7 +63,7 @@ class CountByState(PrefectBaseModel): @field_validator("*") @classmethod - def check_key(cls, value: Optional[Any], info: ValidationInfo): + def check_key(cls, value: Optional[Any], info: ValidationInfo) -> Optional[Any]: if info.field_name not in StateType.__members__: raise ValueError(f"{info.field_name} is not a valid StateType") return value diff --git a/src/prefect/server/task_queue.py b/src/prefect/server/task_queue.py index f63de8607f80..5854eda8ff8d 100644 --- a/src/prefect/server/task_queue.py +++ b/src/prefect/server/task_queue.py @@ -38,7 +38,7 @@ def configure_task_key( task_key: str, scheduled_size: Optional[int] = None, retry_size: Optional[int] = None, - ): + ) -> None: scheduled_size = scheduled_size or cls.default_scheduled_max_size retry_size = retry_size or cls.default_retry_max_size cls._queue_size_configs[task_key] = (scheduled_size, retry_size) diff --git a/src/prefect/server/utilities/database.py b/src/prefect/server/utilities/database.py index 80ff10db658e..d133037f7bad 100644 --- a/src/prefect/server/utilities/database.py +++ b/src/prefect/server/utilities/database.py @@ -5,6 +5,8 @@ allow the Prefect REST API to seamlessly switch between the two. """ +from __future__ import annotations + import datetime import json import operator @@ -15,6 +17,7 @@ Any, Callable, Optional, + Type, TypeVar, Union, overload, @@ -101,8 +104,8 @@ class Timestamp(TypeDecorator[pendulum.DateTime]): as naive timestamps without timezones) and recovered as UTC. """ - impl = sa.TIMESTAMP(timezone=True) - cache_ok = True + impl: TypeEngine[Any] | type[TypeEngine[Any]] = sa.TIMESTAMP(timezone=True) + cache_ok: bool | None = True def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": @@ -150,8 +153,8 @@ class UUID(TypeDecorator[uuid.UUID]): hyphens. """ - impl = TypeEngine - cache_ok = True + impl: type[TypeEngine[Any]] | TypeEngine[Any] = TypeEngine + cache_ok: bool | None = True def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": @@ -191,8 +194,10 @@ class JSON(TypeDecorator[Any]): to SQL compilation """ - impl = postgresql.JSONB - cache_ok = True + impl: type[postgresql.JSONB] | type[TypeEngine[Any]] | TypeEngine[ + Any + ] = postgresql.JSONB + cache_ok: bool | None = True def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": @@ -230,7 +235,7 @@ class Pydantic(TypeDecorator[T]): """ impl = JSON - cache_ok = True + cache_ok: bool | None = True @overload def __init__( @@ -259,7 +264,9 @@ def __init__( super().__init__() self._pydantic_type = pydantic_type if sa_column_type is not None: - self.impl = sa_column_type + self.impl: type[JSON] | type[TypeEngine[Any]] | TypeEngine[ + Any + ] = sa_column_type def process_bind_param( self, value: Optional[T], dialect: sa.Dialect @@ -308,8 +315,8 @@ def bindparams_from_clause( class date_add(functions.GenericFunction[pendulum.DateTime]): """Platform-independent way to add a timestamp and an interval""" - type = Timestamp() - inherit_cache = True + type: Timestamp = Timestamp() + inherit_cache: bool = True def __init__( self, @@ -327,8 +334,8 @@ def __init__( class interval_add(functions.GenericFunction[datetime.timedelta]): """Platform-independent way to add two intervals.""" - type = sa.Interval() - inherit_cache = True + type: sa.Interval = sa.Interval() + inherit_cache: bool = True def __init__( self, @@ -346,8 +353,8 @@ def __init__( class date_diff(functions.GenericFunction[datetime.timedelta]): """Platform-independent difference of two timestamps. Computes d1 - d2.""" - type = sa.Interval() - inherit_cache = True + type: sa.Interval = sa.Interval() + inherit_cache: bool = True def __init__( self, @@ -363,8 +370,8 @@ def __init__( class date_diff_seconds(functions.GenericFunction[float]): """Platform-independent calculation of the number of seconds between two timestamps or from 'now'""" - type = sa.REAL - inherit_cache = True + type: Type[sa.REAL[float]] = sa.REAL + inherit_cache: bool = True def __init__( self, @@ -664,7 +671,7 @@ def sqlite_json_operators( class greatest(functions.ReturnTypeFromArgs[T]): - inherit_cache = True + inherit_cache: bool = True @compiles(greatest, "sqlite") diff --git a/src/prefect/server/utilities/messaging/__init__.py b/src/prefect/server/utilities/messaging/__init__.py index 115f392dc512..9b2103c05660 100644 --- a/src/prefect/server/utilities/messaging/__init__.py +++ b/src/prefect/server/utilities/messaging/__init__.py @@ -2,13 +2,25 @@ from contextlib import asynccontextmanager, AbstractAsyncContextManager from dataclasses import dataclass import importlib -from typing import Any, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Protocol, + TypeVar, + Union, + runtime_checkable, +) from collections.abc import AsyncGenerator, Awaitable, Iterable, Mapping from prefect.settings import PREFECT_MESSAGING_CACHE, PREFECT_MESSAGING_BROKER from prefect.logging import get_logger -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) M = TypeVar("M", bound="Message", covariant=True) @@ -77,13 +89,13 @@ def __init__( deduplicate_by: Optional[str] = None, ) -> None: self.topic = topic - self.cache = cache or create_cache() + self.cache: Cache = cache or create_cache() self.deduplicate_by = deduplicate_by async def __aexit__(self, *args: Any) -> None: pass - async def publish_data(self, data: bytes, attributes: Mapping[str, str]): + async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None: to_publish = [CapturedMessage(data, attributes)] if self.deduplicate_by: diff --git a/src/prefect/server/utilities/messaging/memory.py b/src/prefect/server/utilities/messaging/memory.py index 95934375a7d5..7d1cc4a8d6a2 100644 --- a/src/prefect/server/utilities/messaging/memory.py +++ b/src/prefect/server/utilities/messaging/memory.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass from datetime import timedelta from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from uuid import uuid4 import anyio @@ -19,7 +19,10 @@ from prefect.server.utilities.messaging import Publisher as _Publisher from prefect.settings.context import get_current_settings -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) @dataclass @@ -57,7 +60,7 @@ def __init__( ) -> None: self.topic = topic self.max_retries = max_retries - self.dead_letter_queue_path = ( + self.dead_letter_queue_path: Path = ( Path(dead_letter_queue_path) if dead_letter_queue_path else get_current_settings().home / "dlq" @@ -155,7 +158,7 @@ def subscribe(self) -> Subscription: def unsubscribe(self, subscription: Subscription) -> None: self._subscriptions.remove(subscription) - def clear(self): + def clear(self) -> None: for subscription in self._subscriptions: self.unsubscribe(subscription) self._subscriptions = [] @@ -229,7 +232,7 @@ async def forget_duplicates(self, attribute: str, messages: Iterable[M]) -> None class Publisher(_Publisher): def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = None): - self.topic = Topic.by_name(topic) + self.topic: Topic = Topic.by_name(topic) self.deduplicate_by = deduplicate_by self._cache = cache @@ -254,7 +257,7 @@ async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None class Consumer(_Consumer): def __init__(self, topic: str, subscription: Optional[Subscription] = None): - self.topic = Topic.by_name(topic) + self.topic: Topic = Topic.by_name(topic) if not subscription: subscription = self.topic.subscribe() assert subscription.topic is self.topic diff --git a/src/prefect/server/utilities/user_templates.py b/src/prefect/server/utilities/user_templates.py index 8c721789f0cb..320459b03d62 100644 --- a/src/prefect/server/utilities/user_templates.py +++ b/src/prefect/server/utilities/user_templates.py @@ -1,6 +1,6 @@ """Utilities to support safely rendering user-supplied templates""" -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import jinja2.sandbox from jinja2 import ChainableUndefined, nodes @@ -8,7 +8,10 @@ from prefect.logging import get_logger -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) jinja2.sandbox.MAX_RANGE = 100 @@ -46,12 +49,12 @@ def __init__(self, message: Optional[str] = None, line_number: int = 0) -> None: super().__init__(message) -def register_user_template_filters(filters: dict[str, Any]): +def register_user_template_filters(filters: dict[str, Any]) -> None: """Register additional filters that will be available to user templates""" _template_environment.filters.update(filters) -def validate_user_template(template: str): +def validate_user_template(template: str) -> None: root_node = _template_environment.parse(template) _validate_loop_constraints(root_node) diff --git a/src/prefect/states.py b/src/prefect/states.py index d404dfb5e912..aec37d5215f5 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import datetime import sys @@ -33,12 +35,14 @@ from prefect.utilities.collections import ensure_iterable if TYPE_CHECKING: + import logging + from prefect.results import ( R, ResultStore, ) -logger = get_logger("states") +logger: "logging.Logger" = get_logger("states") @deprecated.deprecated_parameter( @@ -246,7 +250,7 @@ async def exception_to_failed_state( result_store: Optional["ResultStore"] = None, write_result: bool = False, **kwargs: Any, -) -> State: +) -> State[BaseException]: """ Convenience function for creating `Failed` states from exceptions """ @@ -553,17 +557,17 @@ def is_state_iterable(obj: Any) -> TypeGuard[Iterable[State]]: class StateGroup: - def __init__(self, states: Iterable[State]) -> None: - self.states = states - self.type_counts = self._get_type_counts(states) - self.total_count = len(states) - self.cancelled_count = self.type_counts[StateType.CANCELLED] - self.final_count = sum(state.is_final() for state in states) - self.not_final_count = self.total_count - self.final_count - self.paused_count = self.type_counts[StateType.PAUSED] + def __init__(self, states: list[State]) -> None: + self.states: list[State] = states + self.type_counts: dict[StateType, int] = self._get_type_counts(states) + self.total_count: int = len(states) + self.cancelled_count: int = self.type_counts[StateType.CANCELLED] + self.final_count: int = sum(state.is_final() for state in states) + self.not_final_count: int = self.total_count - self.final_count + self.paused_count: int = self.type_counts[StateType.PAUSED] @property - def fail_count(self): + def fail_count(self) -> int: return self.type_counts[StateType.FAILED] + self.type_counts[StateType.CRASHED] def all_completed(self) -> bool: @@ -741,7 +745,7 @@ def Suspended( pause_expiration_time: Optional[datetime.datetime] = None, pause_key: Optional[str] = None, **kwargs: Any, -): +) -> "State[R]": """Convenience function for creating `Suspended` states. Returns: diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 5285b9dec1f0..170d37b1987e 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import logging @@ -28,7 +30,7 @@ import anyio import pendulum from opentelemetry import trace -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Self from prefect.cache_policies import CachePolicy from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client @@ -123,7 +125,7 @@ class BaseTaskRunEngine(Generic[P, R]): _last_event: Optional[PrefectEvent] = None _telemetry: RunTelemetry = field(default_factory=RunTelemetry) - def __post_init__(self): + def __post_init__(self) -> None: if self.parameters is None: self.parameters = {} @@ -239,7 +241,7 @@ def is_running(self) -> bool: return False return task_run.state.is_running() or task_run.state.is_scheduled() - def log_finished_message(self): + def log_finished_message(self) -> None: if not self.task_run: return @@ -295,6 +297,7 @@ def handle_rollback(self, txn: Transaction) -> None: @dataclass class SyncTaskRunEngine(BaseTaskRunEngine[P, R]): + task_run: Optional[TaskRun] = None _client: Optional[SyncPrefectClient] = None @property @@ -337,7 +340,7 @@ def can_retry(self, exc: Exception) -> bool: ) return False - def call_hooks(self, state: Optional[State] = None): + def call_hooks(self, state: Optional[State] = None) -> None: if state is None: state = self.state task = self.task @@ -372,7 +375,7 @@ def call_hooks(self, state: Optional[State] = None): else: self.logger.info(f"Hook {hook_name!r} finished running successfully") - def begin_run(self): + def begin_run(self) -> None: try: self._resolve_parameters() self._set_custom_task_run_name() @@ -547,7 +550,7 @@ def handle_retry(self, exc: Exception) -> bool: ) self.set_state(new_state, force=True) - self.retries = self.retries + 1 + self.retries: int = self.retries + 1 return True elif self.retries >= self.task.retries: self.logger.error( @@ -641,7 +644,9 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): stack.enter_context(ConcurrencyContextV1()) stack.enter_context(ConcurrencyContext()) - self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore + self.logger: "logging.Logger" = task_run_logger( + task_run=self.task_run, task=self.task + ) # type: ignore yield @@ -650,7 +655,7 @@ def initialize_run( self, task_run_id: Optional[UUID] = None, dependencies: Optional[dict[str, set[TaskRunInput]]] = None, - ) -> Generator["SyncTaskRunEngine", Any, Any]: + ) -> Generator[Self, Any, Any]: """ Enters a client context and creates a task run if needed. """ @@ -720,7 +725,7 @@ def initialize_run( self._is_started = False self._client = None - async def wait_until_ready(self): + async def wait_until_ready(self) -> None: """Waits until the scheduled time (if its the future), then enters Running.""" if scheduled_time := self.state.state_details.scheduled_time: sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds() @@ -827,6 +832,7 @@ def call_task_fn( @dataclass class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]): + task_run: TaskRun | None = None _client: Optional[PrefectClient] = None @property @@ -868,7 +874,7 @@ async def can_retry(self, exc: Exception) -> bool: ) return False - async def call_hooks(self, state: Optional[State] = None): + async def call_hooks(self, state: Optional[State] = None) -> None: if state is None: state = self.state task = self.task @@ -903,7 +909,7 @@ async def call_hooks(self, state: Optional[State] = None): else: self.logger.info(f"Hook {hook_name!r} finished running successfully") - async def begin_run(self): + async def begin_run(self) -> None: try: self._resolve_parameters() self._set_custom_task_run_name() @@ -1077,7 +1083,7 @@ async def handle_retry(self, exc: Exception) -> bool: ) await self.set_state(new_state, force=True) - self.retries = self.retries + 1 + self.retries: int = self.retries + 1 return True elif self.retries >= self.task.retries: self.logger.error( @@ -1171,7 +1177,9 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None): ) stack.enter_context(ConcurrencyContext()) - self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore + self.logger: "logging.Logger" = task_run_logger( + task_run=self.task_run, task=self.task + ) # type: ignore yield @@ -1180,7 +1188,7 @@ async def initialize_run( self, task_run_id: Optional[UUID] = None, dependencies: Optional[dict[str, set[TaskRunInput]]] = None, - ) -> AsyncGenerator["AsyncTaskRunEngine", Any]: + ) -> AsyncGenerator[Self, Any]: """ Enters a client context and creates a task run if needed. """ @@ -1248,7 +1256,7 @@ async def initialize_run( self._is_started = False self._client = None - async def wait_until_ready(self): + async def wait_until_ready(self) -> None: """Waits until the scheduled time (if its the future), then enters Running.""" if scheduled_time := self.state.state_details.scheduled_time: sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds() diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py index 1efd284e34fd..9f42a8724cab 100644 --- a/src/prefect/task_runners.py +++ b/src/prefect/task_runners.py @@ -67,7 +67,7 @@ def __init__(self): self._started = False @property - def name(self): + def name(self) -> str: """The name of this task runner""" return type(self).__name__.lower().replace("taskrunner", "") diff --git a/src/prefect/task_runs.py b/src/prefect/task_runs.py index c76ea3f6418a..bdbce7e6138b 100644 --- a/src/prefect/task_runs.py +++ b/src/prefect/task_runs.py @@ -17,6 +17,9 @@ from prefect.events.filters import EventFilter, EventNameFilter from prefect.logging.loggers import get_logger +if TYPE_CHECKING: + import logging + class TaskRunWaiter: """ @@ -70,8 +73,8 @@ async def main(): _instance_lock = threading.Lock() def __init__(self): - self.logger = get_logger("TaskRunWaiter") - self._consumer_task: asyncio.Task[None] | None = None + self.logger: "logging.Logger" = get_logger("TaskRunWaiter") + self._consumer_task: "asyncio.Task[None] | None" = None self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache( maxsize=10000, ttl=600 ) @@ -82,7 +85,7 @@ def __init__(self): self._completion_events_lock = threading.Lock() self._started = False - def start(self): + def start(self) -> None: """ Start the TaskRunWaiter service. """ @@ -145,7 +148,7 @@ async def _consume_events(self, consumer_started: asyncio.Event): except Exception as exc: self.logger.error(f"Error processing event: {exc}") - def stop(self): + def stop(self) -> None: """ Stop the TaskRunWaiter service. """ @@ -159,7 +162,7 @@ def stop(self): @classmethod async def wait_for_task_run( cls, task_run_id: uuid.UUID, timeout: Optional[float] = None - ): + ) -> None: """ Wait for a task run to finish. @@ -225,7 +228,7 @@ def add_done_callback( instance._completion_callbacks[task_run_id] = callback @classmethod - def instance(cls): + def instance(cls) -> Self: """ Get the singleton instance of TaskRunWaiter. """ diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index f1aa9fc0d721..bcc37c3d73aa 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import os @@ -16,7 +18,7 @@ import uvicorn from exceptiongroup import BaseExceptionGroup # novermin from fastapi import FastAPI -from typing_extensions import ParamSpec, TypeVar +from typing_extensions import ParamSpec, Self, TypeVar from websockets.exceptions import InvalidStatusCode from prefect import Task @@ -42,7 +44,10 @@ from prefect.utilities.services import start_client_metrics_server from prefect.utilities.urls import url_for -logger = get_logger("task_worker") +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger("task_worker") P = ParamSpec("P") R = TypeVar("R", infer_variance=True) @@ -85,7 +90,7 @@ class TaskWorker: def __init__( self, *tasks: Task[P, R], - limit: Optional[int] = 10, + limit: int | None = 10, ): self.tasks: list["Task[..., Any]"] = [] for t in tasks: @@ -100,7 +105,7 @@ def __init__( else: self.tasks.append(t.with_options(persist_result=True)) - self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance] + self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance] self._started_at: Optional[pendulum.DateTime] = None self.stopping: bool = False @@ -154,7 +159,7 @@ def current_tasks(self) -> Optional[int]: def available_tasks(self) -> Optional[int]: return int(self._limiter.available_tokens) if self._limiter else None - def handle_sigterm(self, signum: int, frame: object): + def handle_sigterm(self, signum: int, frame: object) -> None: """ Shuts down the task worker when a SIGTERM is received. """ @@ -355,14 +360,14 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun): ) await asyncio.wrap_future(future) - async def execute_task_run(self, task_run: TaskRun): + async def execute_task_run(self, task_run: TaskRun) -> None: """Execute a task run in the task worker.""" async with self if not self.started else asyncnullcontext(): token_acquired = await self._acquire_token(task_run.id) if token_acquired: await self._safe_submit_scheduled_task_run(task_run) - async def __aenter__(self): + async def __aenter__(self) -> Self: logger.debug("Starting task worker...") if self._client._closed: # pyright: ignore[reportPrivateUsage] diff --git a/src/prefect/telemetry/bootstrap.py b/src/prefect/telemetry/bootstrap.py index 6646cf1d9d81..bbd3c95a78ad 100644 --- a/src/prefect/telemetry/bootstrap.py +++ b/src/prefect/telemetry/bootstrap.py @@ -4,7 +4,10 @@ from prefect.client.base import ServerType, determine_server_type from prefect.logging.loggers import get_logger -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) if TYPE_CHECKING: from opentelemetry.sdk._logs import LoggerProvider diff --git a/src/prefect/telemetry/run_telemetry.py b/src/prefect/telemetry/run_telemetry.py index cb43719841cf..abad58f781fe 100644 --- a/src/prefect/telemetry/run_telemetry.py +++ b/src/prefect/telemetry/run_telemetry.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from opentelemetry import propagate, trace from opentelemetry.context import Context @@ -47,14 +49,14 @@ class RunTelemetry: _tracer: "Tracer" = field( default_factory=lambda: get_tracer("prefect", prefect.__version__) ) - span: Optional[Span] = None + span: Span | None = None async def async_start_span( self, run: FlowOrTaskRun, client: PrefectClient, - parameters: Optional[dict[str, Any]] = None, - ): + parameters: dict[str, Any] | None = None, + ) -> Span: traceparent, span = self._start_span(run, parameters) if self._run_type(run) == "flow" and traceparent: @@ -70,8 +72,8 @@ def start_span( self, run: FlowOrTaskRun, client: SyncPrefectClient, - parameters: Optional[dict[str, Any]] = None, - ): + parameters: dict[str, Any] | None = None, + ) -> Span: traceparent, span = self._start_span(run, parameters) if self._run_type(run) == "flow" and traceparent: @@ -84,8 +86,8 @@ def start_span( def _start_span( self, run: FlowOrTaskRun, - parameters: Optional[dict[str, Any]] = None, - ) -> tuple[Optional[str], Span]: + parameters: dict[str, Any] | None = None, + ) -> tuple[str | None, Span]: """ Start a span for a run. """ @@ -139,8 +141,8 @@ def _run_type(self, run: FlowOrTaskRun) -> str: return "task" if isinstance(run, TaskRun) else "flow" def _trace_context_from_labels( - self, labels: Optional[KeyValueLabels] - ) -> Optional[Context]: + self, labels: KeyValueLabels | None + ) -> Context | None: """Get trace context from run labels if it exists.""" if not labels or LABELS_TRACEPARENT_KEY not in labels: return None @@ -148,7 +150,7 @@ def _trace_context_from_labels( carrier = {TRACEPARENT_KEY: traceparent} return propagate.extract(carrier) - def _traceparent_from_span(self, span: Span) -> Optional[str]: + def _traceparent_from_span(self, span: Span) -> str | None: carrier: dict[str, Any] = {} propagate.inject(carrier, context=trace.set_span_in_context(span)) return carrier.get(TRACEPARENT_KEY) @@ -162,7 +164,7 @@ def end_span_on_success(self) -> None: self.span.end(time.time_ns()) self.span = None - def end_span_on_failure(self, terminal_message: Optional[str] = None) -> None: + def end_span_on_failure(self, terminal_message: str | None = None) -> None: """ End a span for a run on failure. """ @@ -203,7 +205,7 @@ def update_run_name(self, name: str) -> None: self.span.update_name(name=name) self.span.set_attribute("prefect.run.name", name) - def _parent_run(self) -> Union[FlowOrTaskRun, None]: + def _parent_run(self) -> FlowOrTaskRun | None: """ Identify the "parent run" for the current execution context. diff --git a/src/prefect/testing/cli.py b/src/prefect/testing/cli.py index 7a660119257b..03c893527336 100644 --- a/src/prefect/testing/cli.py +++ b/src/prefect/testing/cli.py @@ -13,7 +13,7 @@ from prefect.utilities.asyncutils import in_async_main_thread -def check_contains(cli_result: Result, content: str, should_contain: bool): +def check_contains(cli_result: Result, content: str, should_contain: bool) -> None: """ Utility function to see if content is or is not in a CLI result. diff --git a/src/prefect/testing/docker.py b/src/prefect/testing/docker.py index d70a1a2e5542..1e0de02624ad 100644 --- a/src/prefect/testing/docker.py +++ b/src/prefect/testing/docker.py @@ -1,18 +1,18 @@ from contextlib import contextmanager -from typing import Generator, List +from typing import Any, Generator from unittest import mock from prefect.utilities.dockerutils import ImageBuilder @contextmanager -def capture_builders() -> Generator[List[ImageBuilder], None, None]: +def capture_builders() -> Generator[list[ImageBuilder], None, None]: """Captures any instances of ImageBuilder created while this context is active""" - builders = [] + builders: list[ImageBuilder] = [] original_init = ImageBuilder.__init__ - def capture(self, *args, **kwargs): + def capture(self: ImageBuilder, *args: Any, **kwargs: Any): builders.append(self) original_init(self, *args, **kwargs) diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py index 07352f872afc..3b59ab64440e 100644 --- a/src/prefect/testing/fixtures.py +++ b/src/prefect/testing/fixtures.py @@ -4,7 +4,7 @@ import socket import sys from contextlib import contextmanager -from typing import AsyncGenerator, Generator, List, Optional, Union +from typing import Any, AsyncGenerator, Callable, Generator, List, Optional, Union from unittest import mock from uuid import UUID @@ -39,7 +39,9 @@ @pytest.fixture(autouse=True) -def add_prefect_loggers_to_caplog(caplog): +def add_prefect_loggers_to_caplog( + caplog: pytest.LogCaptureFixture, +) -> Generator[None, None, None]: import logging logger = logging.getLogger("prefect") @@ -57,7 +59,9 @@ def is_port_in_use(port: int) -> bool: @pytest.fixture(scope="session") -async def hosted_api_server(unused_tcp_port_factory): +async def hosted_api_server( + unused_tcp_port_factory: Callable[[], int], +) -> AsyncGenerator[str, None]: """ Runs an instance of the Prefect API server in a subprocess instead of the using the ephemeral application. @@ -134,7 +138,7 @@ async def hosted_api_server(unused_tcp_port_factory): @pytest.fixture(autouse=True) -def use_hosted_api_server(hosted_api_server): +def use_hosted_api_server(hosted_api_server: str) -> Generator[str, None, None]: """ Sets `PREFECT_API_URL` to the test session's hosted API endpoint. """ @@ -148,7 +152,7 @@ def use_hosted_api_server(hosted_api_server): @pytest.fixture -def disable_hosted_api_server(): +def disable_hosted_api_server() -> Generator[None, None, None]: """ Disables the hosted API server by setting `PREFECT_API_URL` to `None`. """ @@ -157,11 +161,13 @@ def disable_hosted_api_server(): PREFECT_API_URL: None, } ): - yield hosted_api_server + yield @pytest.fixture -def enable_ephemeral_server(disable_hosted_api_server): +def enable_ephemeral_server( + disable_hosted_api_server: None, +) -> Generator[None, None, None]: """ Enables the ephemeral server by setting `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` to `True`. """ @@ -170,13 +176,15 @@ def enable_ephemeral_server(disable_hosted_api_server): PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, } ): - yield hosted_api_server + yield SubprocessASGIServer().stop() @pytest.fixture -def mock_anyio_sleep(monkeypatch): +def mock_anyio_sleep( + monkeypatch: pytest.MonkeyPatch, +) -> Generator[Callable[[float], None], None, None]: """ Mock sleep used to not actually sleep but to set the current time to now + sleep delay seconds while still yielding to other tasks in the event loop. @@ -188,18 +196,18 @@ def mock_anyio_sleep(monkeypatch): original_sleep = anyio.sleep time_shift = 0.0 - async def callback(delay_in_seconds): + async def callback(delay_in_seconds: float) -> None: nonlocal time_shift time_shift += float(delay_in_seconds) # Preserve yield effects of sleep await original_sleep(0) - def latest_now(*args): + def latest_now(*args: Any) -> pendulum.DateTime: # Fast-forwards the time by the total sleep time return original_now(*args).add( # Ensure we retain float precision seconds=int(time_shift), - microseconds=(time_shift - int(time_shift)) * 1000000, + microseconds=int((time_shift - int(time_shift)) * 1000000), ) monkeypatch.setattr("pendulum.now", latest_now) @@ -368,7 +376,7 @@ def events_cloud_api_url(events_server: WebSocketServer, unused_tcp_port: int) - @pytest.fixture -def mock_should_emit_events(monkeypatch) -> mock.Mock: +def mock_should_emit_events(monkeypatch: pytest.MonkeyPatch) -> mock.Mock: m = mock.Mock() m.return_value = True monkeypatch.setattr("prefect.events.utilities.should_emit_events", m) @@ -376,7 +384,9 @@ def mock_should_emit_events(monkeypatch) -> mock.Mock: @pytest.fixture -def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]: +def asserting_events_worker( + monkeypatch: pytest.MonkeyPatch, +) -> Generator[EventsWorker, None, None]: worker = EventsWorker.instance(AssertingEventsClient) # Always yield the asserting worker when new instances are retrieved monkeypatch.setattr(EventsWorker, "instance", lambda *_: worker) @@ -388,7 +398,7 @@ def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]: @pytest.fixture def asserting_and_emitting_events_worker( - monkeypatch, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[EventsWorker, None, None]: worker = EventsWorker.instance(AssertingPassthroughEventsClient) # Always yield the asserting worker when new instances are retrieved @@ -400,7 +410,9 @@ def asserting_and_emitting_events_worker( @pytest.fixture -async def events_pipeline(asserting_events_worker: EventsWorker): +async def events_pipeline( + asserting_events_worker: EventsWorker, +) -> AsyncGenerator[EventsPipeline, None]: class AssertingEventsPipeline(EventsPipeline): @sync_compatible async def process_events( @@ -435,7 +447,9 @@ async def wait_for_min_events(): @pytest.fixture -async def emitting_events_pipeline(asserting_and_emitting_events_worker: EventsWorker): +async def emitting_events_pipeline( + asserting_and_emitting_events_worker: EventsWorker, +) -> AsyncGenerator[EventsPipeline, None]: class AssertingAndEmittingEventsPipeline(EventsPipeline): @sync_compatible async def process_events(self): @@ -449,14 +463,16 @@ async def process_events(self): @pytest.fixture -def reset_worker_events(asserting_events_worker: EventsWorker): +def reset_worker_events( + asserting_events_worker: EventsWorker, +) -> Generator[None, None, None]: yield assert isinstance(asserting_events_worker._client, AssertingEventsClient) asserting_events_worker._client.events = [] @pytest.fixture -def enable_lineage_events(): +def enable_lineage_events() -> Generator[None, None, None]: """A fixture that ensures lineage events are enabled.""" with temporary_settings(updates={PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED: True}): yield diff --git a/src/prefect/testing/standard_test_suites/blocks.py b/src/prefect/testing/standard_test_suites/blocks.py index 0fa04776784e..f71f80c1244e 100644 --- a/src/prefect/testing/standard_test_suites/blocks.py +++ b/src/prefect/testing/standard_test_suites/blocks.py @@ -14,13 +14,13 @@ class BlockStandardTestSuite(ABC): def block(self) -> type[Block]: pass - def test_has_a_description(self, block: type[Block]): + def test_has_a_description(self, block: type[Block]) -> None: assert block.get_description() - def test_has_a_documentation_url(self, block: type[Block]): + def test_has_a_documentation_url(self, block: type[Block]) -> None: assert block._documentation_url - def test_all_fields_have_a_description(self, block: type[Block]): + def test_all_fields_have_a_description(self, block: type[Block]) -> None: for name, field in block.model_fields.items(): if Block.annotation_refers_to_block_class(field.annotation): # TODO: Block field descriptions aren't currently handled by the UI, so @@ -34,7 +34,7 @@ def test_all_fields_have_a_description(self, block: type[Block]): "." ), f"{name} description on {block.__name__} does not end with a period" - def test_has_a_valid_code_example(self, block: type[Block]): + def test_has_a_valid_code_example(self, block: type[Block]) -> None: code_example = block.get_code_example() assert code_example is not None, f"{block.__name__} is missing a code example" @@ -55,7 +55,7 @@ def test_has_a_valid_code_example(self, block: type[Block]): f" matching the pattern {block_load_pattern}" ) - def test_has_a_valid_image(self, block: type[Block]): + def test_has_a_valid_image(self, block: type[Block]) -> None: logo_url = block._logo_url assert ( logo_url is not None diff --git a/src/prefect/testing/utilities.py b/src/prefect/testing/utilities.py index 2b874240ec39..d22fb2887424 100644 --- a/src/prefect/testing/utilities.py +++ b/src/prefect/testing/utilities.py @@ -2,6 +2,8 @@ Internal utilities for tests. """ +from __future__ import annotations + import atexit import shutil import warnings @@ -9,7 +11,7 @@ from pathlib import Path from pprint import pprint from tempfile import mkdtemp -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Generator import prefect.context import prefect.settings @@ -30,10 +32,11 @@ if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient + from prefect.client.schemas.objects import FlowRun from prefect.filesystems import ReadableFileSystem -def exceptions_equal(a, b): +def exceptions_equal(a: Exception, b: Exception) -> bool: """ Exceptions cannot be compared by `==`. They can be compared using `is` but this will fail if the exception is serialized/deserialized so this utility does its @@ -54,9 +57,9 @@ def exceptions_equal(a, b): def kubernetes_environments_equal( - actual: List[Dict[str, str]], - expected: Union[List[Dict[str, str]], Dict[str, str]], -): + actual: list[dict[str, str]], + expected: list[dict[str, str]] | dict[str, str], +) -> bool: # Convert to a required format and sort by name if isinstance(expected, dict): expected = [{"name": key, "value": value} for key, value in expected.items()] @@ -90,7 +93,9 @@ def kubernetes_environments_equal( @contextmanager -def assert_does_not_warn(ignore_warnings=[]): +def assert_does_not_warn( + ignore_warnings: list[type[Warning]] | None = None, +) -> Generator[None, None, None]: """ Converts warnings to errors within this context to assert warnings are not raised, except for those specified in ignore_warnings. @@ -98,6 +103,7 @@ def assert_does_not_warn(ignore_warnings=[]): Parameters: - ignore_warnings: List of warning types to ignore. Example: [DeprecationWarning, UserWarning] """ + ignore_warnings = ignore_warnings or [] with warnings.catch_warnings(): warnings.simplefilter("error") for warning_type in ignore_warnings: @@ -110,7 +116,7 @@ def assert_does_not_warn(ignore_warnings=[]): @contextmanager -def prefect_test_harness(server_startup_timeout: Optional[int] = 30): +def prefect_test_harness(server_startup_timeout: int | None = 30): """ Temporarily run flows against a local SQLite database for testing. @@ -175,7 +181,7 @@ def cleanup_temp_dir(temp_dir): test_server.stop() -async def get_most_recent_flow_run(client: "PrefectClient" = None): +async def get_most_recent_flow_run(client: "PrefectClient | None" = None) -> "FlowRun": if client is None: client = get_client() @@ -187,8 +193,8 @@ async def get_most_recent_flow_run(client: "PrefectClient" = None): def assert_blocks_equal( - found: Block, expected: Block, exclude_private: bool = True, **kwargs -) -> bool: + found: Block, expected: Block, exclude_private: bool = True, **kwargs: Any +) -> None: assert isinstance( found, type(expected) ), f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}" @@ -204,8 +210,8 @@ def assert_blocks_equal( async def assert_uses_result_serializer( - state: State, serializer: Union[str, Serializer], client: "PrefectClient" -): + state: State, serializer: str | Serializer, client: "PrefectClient" +) -> None: assert isinstance(state.data, (ResultRecord, ResultRecordMetadata)) if isinstance(state.data, ResultRecord): result_serializer = state.data.metadata.serializer @@ -240,8 +246,8 @@ async def assert_uses_result_serializer( @inject_client async def assert_uses_result_storage( - state: State, storage: Union[str, "ReadableFileSystem"], client: "PrefectClient" -): + state: State, storage: "str | ReadableFileSystem", client: "PrefectClient" +) -> None: assert isinstance(state.data, (ResultRecord, ResultRecordMetadata)) if isinstance(state.data, ResultRecord): assert_blocks_equal( @@ -267,11 +273,11 @@ async def assert_uses_result_storage( ) -def a_test_step(**kwargs): +def a_test_step(**kwargs: Any) -> dict[str, Any]: kwargs.update({"output1": 1, "output2": ["b", 2, 3]}) return kwargs -def b_test_step(**kwargs): +def b_test_step(**kwargs: Any) -> dict[str, Any]: kwargs.update({"output1": 1, "output2": ["b", 2, 3]}) return kwargs diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 6e882dbf5c6f..e8cdd65034f4 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -173,7 +173,7 @@ def is_pending(self) -> bool: def is_active(self) -> bool: return self.state == TransactionState.ACTIVE - def __enter__(self): + def __enter__(self) -> Self: if self._token is not None: raise RuntimeError( "Context already entered. Context enter calls cannot be nested." @@ -206,7 +206,7 @@ def __enter__(self): self._token = self.__var__.set(self) return self - def __exit__(self, *exc_info: Any): + def __exit__(self, *exc_info: Any) -> None: exc_type, exc_val, _ = exc_info if not self._token: raise RuntimeError( @@ -235,7 +235,7 @@ def __exit__(self, *exc_info: Any): self.reset() - def begin(self): + def begin(self) -> None: if ( self.store and self.key diff --git a/src/prefect/types/__init__.py b/src/prefect/types/__init__.py index f36622f5a3df..c70e5cc3c8b1 100644 --- a/src/prefect/types/__init__.py +++ b/src/prefect/types/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import partial from typing import Annotated, Any, Dict, List, Optional, Set, TypeVar, Union from typing_extensions import Literal, TypeAlias @@ -114,7 +116,7 @@ class SecretDict(pydantic.Secret[Dict[str, Any]]): def validate_set_T_from_delim_string( - value: Union[str, T, Set[T], None], type_, delim=None + value: Union[str, T, Set[T], None], type_: type[T], delim: str | None = None ) -> Set[T]: """ "no-info" before validator useful in scooping env vars diff --git a/src/prefect/utilities/_deprecated.py b/src/prefect/utilities/_deprecated.py new file mode 100644 index 000000000000..2d5fd8a87c7c --- /dev/null +++ b/src/prefect/utilities/_deprecated.py @@ -0,0 +1,38 @@ +from typing import Any + +from jsonpatch import ( # type: ignore # no typing stubs available, see https://github.com/stefankoegl/python-json-patch/issues/158 + JsonPatch as JsonPatchBase, +) +from pydantic import GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema + + +class JsonPatch(JsonPatchBase): + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetJsonSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.typed_dict_schema( + {"patch": core_schema.typed_dict_field(core_schema.dict_schema())} + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema = handler(core_schema) + json_schema = handler.resolve_ref_schema(json_schema) + json_schema.pop("required", None) + json_schema.pop("properties", None) + json_schema.update( + { + "type": "array", + "format": "rfc6902", + "items": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + } + ) + return json_schema diff --git a/src/prefect/utilities/filesystem.py b/src/prefect/utilities/filesystem.py index 1301a1178c91..92579c444c87 100644 --- a/src/prefect/utilities/filesystem.py +++ b/src/prefect/utilities/filesystem.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from contextlib import contextmanager from pathlib import Path, PureWindowsPath -from typing import AnyStr, Optional, Union, cast +from typing import Any, AnyStr, Optional, Union, cast # fsspec has no stubs, see https://github.com/fsspec/filesystem_spec/issues/625 import fsspec # type: ignore @@ -114,7 +114,7 @@ def filename(path: str) -> str: return path.split(sep)[-1] -def is_local_path(path: Union[str, pathlib.Path, OpenFile]) -> bool: +def is_local_path(path: Union[str, pathlib.Path, Any]) -> bool: """Check if the given path points to a local or remote file system""" if isinstance(path, str): try: diff --git a/src/prefect/utilities/generics.py b/src/prefect/utilities/generics.py index a3ff954c247c..9d5097b4e6aa 100644 --- a/src/prefect/utilities/generics.py +++ b/src/prefect/utilities/generics.py @@ -5,7 +5,7 @@ T = TypeVar("T", bound=BaseModel) -ListValidator = SchemaValidator( +ListValidator: SchemaValidator = SchemaValidator( schema=core_schema.list_schema( items_schema=core_schema.dict_schema( keys_schema=core_schema.str_schema(), values_schema=core_schema.any_schema() diff --git a/src/prefect/utilities/pydantic.py b/src/prefect/utilities/pydantic.py index 6086931fbbbd..381456ca23bc 100644 --- a/src/prefect/utilities/pydantic.py +++ b/src/prefect/utilities/pydantic.py @@ -1,3 +1,4 @@ +import warnings from typing import ( Any, Callable, @@ -10,18 +11,13 @@ overload, ) -from jsonpatch import ( # type: ignore # no typing stubs available, see https://github.com/stefankoegl/python-json-patch/issues/158 - JsonPatch as JsonPatchBase, -) from pydantic import ( BaseModel, - GetJsonSchemaHandler, Secret, TypeAdapter, ValidationError, ) -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import core_schema, to_jsonable_python +from pydantic_core import to_jsonable_python from typing_extensions import Literal from prefect.utilities.dispatch import get_dispatch_key, lookup_type, register_base_type @@ -262,36 +258,6 @@ def __repr__(self) -> str: return f"PartialModel(cls={self.model_cls.__name__}, {dsp_fields})" -class JsonPatch(JsonPatchBase): - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetJsonSchemaHandler - ) -> core_schema.CoreSchema: - return core_schema.typed_dict_schema( - {"patch": core_schema.typed_dict_field(core_schema.dict_schema())} - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - json_schema = handler(core_schema) - json_schema = handler.resolve_ref_schema(json_schema) - json_schema.pop("required", None) - json_schema.pop("properties", None) - json_schema.update( - { - "type": "array", - "format": "rfc6902", - "items": { - "type": "object", - "additionalProperties": {"type": "string"}, - }, - } - ) - return json_schema - - def custom_pydantic_encoder( type_encoders: dict[Any, Callable[[type[Any]], Any]], obj: Any ) -> Any: @@ -382,3 +348,22 @@ def handle_secret_render(value: object, context: dict[str, Any]) -> object: elif isinstance(value, BaseModel): return value.model_dump(context=context) return value + + +def __getattr__(name: str) -> Any: + """ + Handles imports from this module that are deprecated. + """ + + if name == "JsonPatch": + warnings.warn( + "JsonPatch is deprecated and will be removed after March 2025. " + "Please use `JsonPatch` from the `jsonpatch` package instead.", + DeprecationWarning, + stacklevel=2, + ) + from ._deprecated import JsonPatch + + return JsonPatch + else: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index ebc67c05d794..16b0d4534945 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -1,9 +1,22 @@ +from __future__ import annotations + import abc import asyncio import threading from contextlib import AsyncExitStack from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + List, + Optional, + Set, + Type, + Union, +) from uuid import UUID, uuid4 import anyio @@ -13,7 +26,7 @@ from importlib_metadata import distributions from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal +from typing_extensions import Literal, Self, TypeVar import prefect from prefect._internal.schemas.validators import return_v_or_none @@ -104,7 +117,7 @@ class BaseJobConfiguration(BaseModel): _related_objects: Dict[str, Any] = PrivateAttr(default_factory=dict) @property - def is_using_a_runner(self): + def is_using_a_runner(self) -> bool: return self.command is not None and "prefect flow-run execute" in self.command @field_validator("command") @@ -175,7 +188,7 @@ async def from_template_and_values( return cls(**populated_configuration) @classmethod - def json_template(cls) -> dict: + def json_template(cls) -> dict[str, Any]: """Returns a dict with job configuration as keys and the corresponding templates as values Defaults to using the job configuration parameter name as the template variable name. @@ -186,7 +199,7 @@ def json_template(cls) -> dict: key2: '{{ template2 }}', # `template2` specifically provide as template } """ - configuration = {} + configuration: dict[str, Any] = {} properties = cls.model_json_schema()["properties"] for k, v in properties.items(): if v.get("template"): @@ -202,7 +215,7 @@ def prepare_for_flow_run( flow_run: "FlowRun", deployment: Optional["DeploymentResponse"] = None, flow: Optional["Flow"] = None, - ): + ) -> None: """ Prepare the job configuration for a flow run. @@ -368,15 +381,20 @@ class BaseWorkerResult(BaseModel, abc.ABC): identifier: str status_code: int - def __bool__(self): + def __bool__(self) -> bool: return self.status_code == 0 +C = TypeVar("C", bound=BaseJobConfiguration) +V = TypeVar("V", bound=BaseVariables) +R = TypeVar("R", bound=BaseWorkerResult) + + @register_base_type -class BaseWorker(abc.ABC): +class BaseWorker(abc.ABC, Generic[C, V, R]): type: str - job_configuration: Type[BaseJobConfiguration] = BaseJobConfiguration - job_configuration_variables: Optional[Type[BaseVariables]] = None + job_configuration: Type[C] = BaseJobConfiguration # type: ignore + job_configuration_variables: Optional[Type[V]] = None _documentation_url = "" _logo_url = "" @@ -418,7 +436,7 @@ def __init__( """ if name and ("/" in name or "%" in name): raise ValueError("Worker name cannot contain '/' or '%'") - self.name = name or f"{self.__class__.__name__} {uuid4()}" + self.name: str = name or f"{self.__class__.__name__} {uuid4()}" self._started_event: Optional[Event] = None self.backend_id: Optional[UUID] = None self._logger = get_worker_logger(self) @@ -432,7 +450,7 @@ def __init__( self._prefetch_seconds: float = ( prefetch_seconds or PREFECT_WORKER_PREFETCH_SECONDS.value() ) - self.heartbeat_interval_seconds = ( + self.heartbeat_interval_seconds: int = ( heartbeat_interval_seconds or PREFECT_WORKER_HEARTBEAT_SECONDS.value() ) @@ -461,7 +479,7 @@ def get_description(cls) -> str: return cls._description @classmethod - def get_default_base_job_template(cls) -> Dict: + def get_default_base_job_template(cls) -> dict[str, Any]: if cls.job_configuration_variables is None: schema = cls.job_configuration.model_json_schema() # remove "template" key from all dicts in schema['properties'] because it is not a @@ -479,7 +497,9 @@ def get_default_base_job_template(cls) -> Dict: } @staticmethod - def get_worker_class_from_type(type: str) -> Optional[Type["BaseWorker"]]: + def get_worker_class_from_type( + type: str, + ) -> Optional[Type["BaseWorker[Any, Any, Any]"]]: """ Returns the worker class for a given worker type. If the worker type is not recognized, returns None. @@ -500,7 +520,7 @@ def get_all_available_worker_types() -> List[str]: return list(worker_registry.keys()) return [] - def get_name_slug(self): + def get_name_slug(self) -> str: return slugify(self.name) def get_flow_run_logger(self, flow_run: "FlowRun") -> PrefectLogAdapter: @@ -524,7 +544,7 @@ async def start( run_once: bool = False, with_healthcheck: bool = False, printer: Callable[..., None] = print, - ): + ) -> None: """ Starts the worker and runs the main worker loops. @@ -603,9 +623,9 @@ async def start( async def run( self, flow_run: "FlowRun", - configuration: BaseJobConfiguration, - task_status: Optional[anyio.abc.TaskStatus] = None, - ) -> BaseWorkerResult: + configuration: C, + task_status: Optional[anyio.abc.TaskStatus[int]] = None, + ) -> R: """ Runs a given flow run on the current worker. """ @@ -614,12 +634,12 @@ async def run( ) @classmethod - def __dispatch_key__(cls): + def __dispatch_key__(cls) -> str | None: if cls.__name__ == "BaseWorker": return None # The base class is abstract return cls.type - async def setup(self): + async def setup(self) -> None: """Prepares the worker to run.""" self._logger.debug("Setting up worker...") self._runs_task_group = anyio.create_task_group() @@ -637,10 +657,10 @@ async def setup(self): self.is_setup = True - async def teardown(self, *exc_info): + async def teardown(self, *exc_info: Any) -> None: """Cleans up resources after the worker is stopped.""" self._logger.debug("Tearing down worker...") - self.is_setup = False + self.is_setup: bool = False for scope in self._scheduled_task_scopes: scope.cancel() @@ -684,14 +704,16 @@ def is_worker_still_polling(self, query_interval_seconds: float) -> bool: return is_still_polling - async def get_and_submit_flow_runs(self): + async def get_and_submit_flow_runs(self) -> list["FlowRun"]: runs_response = await self._get_scheduled_flow_runs() self._last_polled_time = pendulum.now("utc") return await self._submit_scheduled_flow_runs(flow_run_response=runs_response) - async def _update_local_work_pool_info(self): + async def _update_local_work_pool_info(self) -> None: + if TYPE_CHECKING: + assert self._client is not None try: work_pool = await self._client.read_work_pool( work_pool_name=self._work_pool_name @@ -803,7 +825,7 @@ async def _send_worker_heartbeat(self) -> Optional[UUID]: return worker_id - async def sync_with_backend(self): + async def sync_with_backend(self) -> None: """ Updates the worker's local information about it's current work pool and queues. Sends a worker heartbeat to the API. @@ -1042,7 +1064,7 @@ def _release_limit_slot(self, flow_run_id: str) -> None: self._limiter.release_on_behalf_of(flow_run_id) self._logger.debug("Limit slot released for flow run '%s'", flow_run_id) - def get_status(self): + def get_status(self) -> dict[str, Any]: """ Retrieves the status of the current worker including its name, current worker pool, the work pool queues it is polling, and its local settings. @@ -1234,17 +1256,17 @@ async def _give_worker_labels_to_flow_run(self, flow_run_id: UUID): await self._client.update_flow_run_labels(flow_run_id, labels) - async def __aenter__(self): + async def __aenter__(self) -> Self: self._logger.debug("Entering worker context...") await self.setup() return self - async def __aexit__(self, *exc_info): + async def __aexit__(self, *exc_info: Any) -> None: self._logger.debug("Exiting worker context...") await self.teardown(*exc_info) - def __repr__(self): + def __repr__(self) -> str: return f"Worker(pool={self._work_pool_name!r}, name={self.name!r})" def _event_resource(self): diff --git a/src/prefect/workers/process.py b/src/prefect/workers/process.py index 4a89665f4fca..e0059d185395 100644 --- a/src/prefect/workers/process.py +++ b/src/prefect/workers/process.py @@ -13,6 +13,7 @@ For more information about work pools and workers, checkout out the [Prefect docs](/concepts/work-pools/). """ +from __future__ import annotations import contextlib import os @@ -30,7 +31,7 @@ import anyio.abc from pydantic import Field, field_validator -from prefect._internal.schemas.validators import validate_command +from prefect._internal.schemas.validators import validate_working_dir from prefect.client.schemas import FlowRun from prefect.client.schemas.filters import ( FlowRunFilter, @@ -85,19 +86,21 @@ class ProcessJobConfiguration(BaseJobConfiguration): @field_validator("working_dir") @classmethod - def validate_command(cls, v: str) -> str: - return validate_command(v) + def validate_working_dir(cls, v: Path | str | None) -> Path | None: + if isinstance(v, str): + return validate_working_dir(v) + return v def prepare_for_flow_run( self, flow_run: "FlowRun", deployment: Optional["DeploymentResponse"] = None, flow: Optional["Flow"] = None, - ): + ) -> None: super().prepare_for_flow_run(flow_run, deployment, flow) - self.env = {**os.environ, **self.env} - self.command = ( + self.env: dict[str, str | None] = {**os.environ, **self.env} + self.command: str | None = ( f"{get_sys_executable()} -m prefect.engine" if self.command == self._base_flow_run_command() else self.command @@ -134,10 +137,12 @@ class ProcessWorkerResult(BaseWorkerResult): """Contains information about the final state of a completed process""" -class ProcessWorker(BaseWorker): +class ProcessWorker( + BaseWorker[ProcessJobConfiguration, ProcessVariables, ProcessWorkerResult] +): type = "process" - job_configuration = ProcessJobConfiguration - job_configuration_variables = ProcessVariables + job_configuration: type[ProcessJobConfiguration] = ProcessJobConfiguration + job_configuration_variables: type[ProcessVariables] | None = ProcessVariables _description = ( "Execute flow runs as subprocesses on a worker. Works well for local execution" @@ -152,7 +157,7 @@ async def start( run_once: bool = False, with_healthcheck: bool = False, printer: Callable[..., None] = print, - ): + ) -> None: """ Starts the worker and runs the main worker loops. @@ -241,8 +246,8 @@ async def run( self, flow_run: FlowRun, configuration: ProcessJobConfiguration, - task_status: Optional[anyio.abc.TaskStatus] = None, - ): + task_status: Optional[anyio.abc.TaskStatus[int]] = None, + ) -> ProcessWorkerResult: command = configuration.command if not command: command = f"{get_sys_executable()} -m prefect.engine" @@ -322,7 +327,7 @@ async def kill_process( self, infrastructure_pid: str, grace_seconds: int = 30, - ): + ) -> None: hostname, pid = _parse_infrastructure_pid(infrastructure_pid) if hostname != socket.gethostname(): @@ -372,7 +377,7 @@ async def kill_process( # process ended right after the check above. return - async def check_for_cancelled_flow_runs(self): + async def check_for_cancelled_flow_runs(self) -> list["FlowRun"]: if not self.is_setup: raise RuntimeError( "Worker is not set up. Please make sure you are running this worker " @@ -429,7 +434,7 @@ async def check_for_cancelled_flow_runs(self): return cancelling_flow_runs - async def cancel_run(self, flow_run: "FlowRun"): + async def cancel_run(self, flow_run: "FlowRun") -> None: run_logger = self.get_flow_run_logger(flow_run) try: diff --git a/src/prefect/workers/server.py b/src/prefect/workers/server.py index e513df3c22ef..a073621afc9a 100644 --- a/src/prefect/workers/server.py +++ b/src/prefect/workers/server.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any import uvicorn import uvicorn.server @@ -10,14 +10,13 @@ PREFECT_WORKER_WEBSERVER_PORT, ) from prefect.workers.base import BaseWorker -from prefect.workers.process import ProcessWorker def build_healthcheck_server( - worker: Union[BaseWorker, ProcessWorker], + worker: BaseWorker[Any, Any, Any], query_interval_seconds: float, log_level: str = "error", -): +) -> uvicorn.Server: """ Build a healthcheck FastAPI server for a worker. @@ -54,7 +53,7 @@ def perform_health_check(): def start_healthcheck_server( - worker: Union[BaseWorker, ProcessWorker], + worker: BaseWorker[Any, Any, Any], query_interval_seconds: float, log_level: str = "error", ) -> None: diff --git a/tests/infrastructure/provisioners/test_coiled.py b/tests/infrastructure/provisioners/test_coiled.py new file mode 100644 index 000000000000..eca5cdceb11f --- /dev/null +++ b/tests/infrastructure/provisioners/test_coiled.py @@ -0,0 +1,186 @@ +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import UUID + +import pytest + +from prefect.blocks.core import Block +from prefect.client.orchestration import PrefectClient +from prefect.infrastructure.provisioners.coiled import CoiledPushProvisioner + + +@pytest.fixture(autouse=True) +async def coiled_credentials_block_cls(): + class MockCoiledCredentials(Block): + _block_type_name = "Coiled Credentials" + api_token: str + + await MockCoiledCredentials.register_type_and_schema() + + return MockCoiledCredentials + + +@pytest.fixture +async def coiled_credentials_block_id(coiled_credentials_block_cls: Block): + block_doc_id = await coiled_credentials_block_cls(api_token="existing_token").save( + "work-pool-name-coiled-credentials", overwrite=True + ) + + return block_doc_id + + +@pytest.fixture +def mock_run_process(): + with patch("prefect.infrastructure.provisioners.coiled.run_process") as mock: + yield mock + + +@pytest.fixture +def mock_coiled(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("prefect.infrastructure.provisioners.coiled.coiled", mock) + yield mock + + +@pytest.fixture +def mock_importlib(): + with patch("prefect.infrastructure.provisioners.coiled.importlib") as mock: + yield mock + + +@pytest.fixture +def mock_confirm(): + with patch("prefect.infrastructure.provisioners.coiled.Confirm") as mock: + yield mock + + +@pytest.fixture +def mock_dask_config(): + with patch( + "prefect.infrastructure.provisioners.coiled.CoiledPushProvisioner._get_coiled_token" + ) as mock: + mock.return_value = "local-api-token-from-dask-config" + yield mock + + +async def test_provision( + prefect_client: PrefectClient, + mock_run_process: AsyncMock, + mock_coiled: MagicMock, + mock_dask_config: MagicMock, + mock_confirm: MagicMock, + mock_importlib: MagicMock, +): + """ + Test provision from a clean slate: + - Coiled is not installed + - Coiled token does not exist + - CoiledCredentials block does not exist + """ + provisioner = CoiledPushProvisioner() + provisioner.console.is_interactive = True + + mock_confirm.ask.side_effect = [ + True, + True, + True, + ] # confirm provision, install coiled, create new token + mock_importlib.import_module.side_effect = [ + ModuleNotFoundError, + mock_coiled, + mock_coiled, + ] + # simulate coiled token creation + mock_coiled.config.Config.return_value.get.side_effect = [ + None, + None, + "mock_token", + ] + + work_pool_name = "work-pool-name" + base_job_template = {"variables": {"properties": {"credentials": {}}}} + + result = await provisioner.provision( + work_pool_name, base_job_template, client=prefect_client + ) + + # Check if the block document exists and has expected values + block_document = await prefect_client.read_block_document_by_name( + "work-pool-name-coiled-credentials", "coiled-credentials" + ) + + assert block_document.data["api_token"], str == "mock_token" + + # Check if the base job template was updated + assert result["variables"]["properties"]["credentials"] == { + "default": {"$ref": {"block_document_id": str(block_document.id)}}, + } + + +async def test_provision_existing_coiled_credentials_block( + prefect_client: PrefectClient, + coiled_credentials_block_id: UUID, + mock_run_process: AsyncMock, +): + """ + Test provision with an existing CoiledCredentials block. + """ + provisioner = CoiledPushProvisioner() + + work_pool_name = "work-pool-name" + base_job_template = {"variables": {"properties": {"credentials": {}}}} + + result = await provisioner.provision( + work_pool_name, base_job_template, client=prefect_client + ) + + # Check if the base job template was updated + assert result["variables"]["properties"]["credentials"] == { + "default": {"$ref": {"block_document_id": str(coiled_credentials_block_id)}}, + } + + mock_run_process.assert_not_called() + + +async def test_provision_existing_coiled_credentials( + prefect_client: PrefectClient, + mock_run_process: AsyncMock, + mock_coiled: MagicMock, + mock_dask_config: MagicMock, + mock_confirm: MagicMock, + mock_importlib: MagicMock, +): + """ + Test provision where the user has coiled installed and an existing Coiled configuration. + """ + provisioner = CoiledPushProvisioner() + mock_confirm.ask.side_effect = [ + True, + ] # confirm provision + mock_importlib.import_module.side_effect = [ + mock_coiled, + mock_coiled, + ] # coiled is already installed + mock_coiled.config.Config.return_value.get.side_effect = [ + "mock_token", + ] # coiled config exists + + work_pool_name = "work-pool-name" + base_job_template = {"variables": {"properties": {"credentials": {}}}} + + result = await provisioner.provision( + work_pool_name, base_job_template, client=prefect_client + ) + + # Check if the block document exists and has expected values + block_document = await prefect_client.read_block_document_by_name( + "work-pool-name-coiled-credentials", "coiled-credentials" + ) + + assert block_document.data["api_token"], str == "mock_token" + + # Check if the base job template was updated + assert result["variables"]["properties"]["credentials"] == { + "default": {"$ref": {"block_document_id": str(block_document.id)}}, + } + + mock_run_process.assert_not_called() diff --git a/tests/utilities/test_pydantic.py b/tests/utilities/test_pydantic.py index d867a7a3ed38..10bed971b300 100644 --- a/tests/utilities/test_pydantic.py +++ b/tests/utilities/test_pydantic.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path import cloudpickle @@ -6,12 +7,15 @@ from prefect.utilities.dispatch import register_type from prefect.utilities.pydantic import ( - JsonPatch, PartialModel, add_cloudpickle_reduction, get_class_fields_only, ) +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + from prefect.utilities.pydantic import JsonPatch + class SimplePydantic(BaseModel): x: int diff --git a/ui-v2/src/api/prefect.ts b/ui-v2/src/api/prefect.ts index 738414555336..5bf741b1fea9 100644 --- a/ui-v2/src/api/prefect.ts +++ b/ui-v2/src/api/prefect.ts @@ -8487,7 +8487,9 @@ export interface components { /** @description The current task run state. */ state?: components["schemas"]["State"] | null; }; - TaskRunCount: Record; + TaskRunCount: { + [key: string]: number; + }; /** * TaskRunCreate * @description Data used by the Prefect REST API to create a task run @@ -15409,7 +15411,7 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": unknown; + "application/json": Record; }; }; /** @description Validation Error */ @@ -15671,7 +15673,7 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": unknown; + "application/json": string; }; }; /** @description Validation Error */ diff --git a/ui/src/pages/Unauthenticated.vue b/ui/src/pages/Unauthenticated.vue index faa79ee9fa3f..5b3d4bf7e1a7 100644 --- a/ui/src/pages/Unauthenticated.vue +++ b/ui/src/pages/Unauthenticated.vue @@ -8,7 +8,7 @@ => { try { localStorage.setItem('prefect-password', btoa(password.value)) - router.push(props.redirect || '/') + api.admin.authCheck().then(status_code => { + if (status_code == 401) { + localStorage.removeItem('prefect-pasword') + showToast('Authentication failed.', 'error', { timeout: false }) + if (router.currentRoute.value.name !== 'login') { + router.push({ + name: 'login', + query: { redirect: router.currentRoute.value.fullPath } + }) + } + } else { + api.health.isHealthy().then(healthy => { + if (!healthy) { + showToast(`Can't connect to Server API at ${config.baseUrl}. Check that it's accessible from your machine.`, 'error', { timeout: false }) + } + router.push(props.redirect || '/') + }) + } + }) } catch (e) { localStorage.removeItem('prefect-password') error.value = 'Invalid password'