diff --git a/.github/workflows/static-analysis.yaml b/.github/workflows/static-analysis.yaml
index 06bdb8b2cc92..e8fe3c2f3ab9 100644
--- a/.github/workflows/static-analysis.yaml
+++ b/.github/workflows/static-analysis.yaml
@@ -66,63 +66,4 @@ jobs:
- name: Run pre-commit
run: |
- pre-commit run --show-diff-on-failure --color=always --all-files
-
- type-completeness-check:
- name: Type completeness check
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- with:
- persist-credentials: false
- fetch-depth: 0
-
- - name: Set up uv
- uses: astral-sh/setup-uv@v5
- with:
- python-version: "3.12"
-
- - name: Calculate type completeness score
- id: calculate_current_score
- run: |
- # `pyright` will exit with a non-zero status code if it finds any issues,
- # so we need to explicitly ignore the exit code with `|| true`.
- uv tool run --with-editable . pyright --verifytypes prefect --ignoreexternal --outputjson > prefect-analysis.json || true
- SCORE=$(jq -r '.typeCompleteness.completenessScore' prefect-analysis.json)
- echo "current_score=$SCORE" >> $GITHUB_OUTPUT
-
- - name: Checkout base branch
- run: |
- git checkout ${{ github.base_ref }}
-
- - name: Calculate base branch score
- id: calculate_base_score
- run: |
- uv tool run --with-editable . pyright --verifytypes prefect --ignoreexternal --outputjson > prefect-analysis-base.json || true
- BASE_SCORE=$(jq -r '.typeCompleteness.completenessScore' prefect-analysis-base.json)
- echo "base_score=$BASE_SCORE" >> $GITHUB_OUTPUT
-
- - name: Compare scores
- run: |
- CURRENT_SCORE=$(echo ${{ steps.calculate_current_score.outputs.current_score }})
- BASE_SCORE=$(echo ${{ steps.calculate_base_score.outputs.base_score }})
-
- if (( $(echo "$BASE_SCORE > $CURRENT_SCORE" | bc -l) )); then
- echo "::notice title=Type Completeness Check::We noticed a decrease in type coverage with these changes. Check workflow summary for more details."
- echo "### ℹ️ Type Completeness Check" >> $GITHUB_STEP_SUMMARY
- echo "We noticed a decrease in type coverage with these changes. To maintain our codebase quality, we aim to keep or improve type coverage with each change." >> $GITHUB_STEP_SUMMARY
- echo "" >> $GITHUB_STEP_SUMMARY
- echo "Need help? Ping @desertaxle or @zzstoatzz for assistance!" >> $GITHUB_STEP_SUMMARY
- echo "" >> $GITHUB_STEP_SUMMARY
- echo "Here's what changed:" >> $GITHUB_STEP_SUMMARY
- echo "" >> $GITHUB_STEP_SUMMARY
- uv run scripts/pyright_diff.py prefect-analysis-base.json prefect-analysis.json >> $GITHUB_STEP_SUMMARY
- SCORE_DIFF=$(echo "$BASE_SCORE - $CURRENT_SCORE" | bc -l)
- if (( $(echo "$SCORE_DIFF > 0.001" | bc -l) )); then
- exit 1
- fi
- elif (( $(echo "$BASE_SCORE < $CURRENT_SCORE" | bc -l) )); then
- echo "🎉 Great work! The type coverage has improved with these changes" >> $GITHUB_STEP_SUMMARY
- else
- echo "✅ Type coverage maintained" >> $GITHUB_STEP_SUMMARY
- fi
+ pre-commit run --show-diff-on-failure --color=always --all-files
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2773300caa3a..eda214b85bc9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,6 +32,11 @@ repos:
)$
- repo: local
hooks:
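+      # Fail the commit if pyright's --verifytypes check finds public symbols
+      # in `prefect` that are missing type annotations (pyright exits non-zero
+      # on any issue, as noted in the removed workflow above).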
+ - id: type-completeness-check
+ name: Type Completeness Check
+ language: system
+ entry: uv run --with pyright pyright --ignoreexternal --verifytypes prefect
+ pass_filenames: false
- id: generate-mintlify-openapi-docs
name: Generating OpenAPI docs for Mintlify
language: system
diff --git a/README.md b/README.md
index 11dfb13c0366..3c8e7cbce6ed 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,28 @@
+
+Installation · Quickstart · Build workflows · Deploy workflows · Prefect Cloud
+
# Prefect
Prefect is a workflow orchestration framework for building data pipelines in Python.
@@ -26,6 +48,11 @@ With just a few lines of code, data teams can confidently automate any data proc
Workflow activity is tracked and can be monitored with a self-hosted [Prefect server](https://docs.prefect.io/latest/manage/self-host/?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) instance or managed [Prefect Cloud](https://www.prefect.io/cloud-vs-oss?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) dashboard.
+> [!TIP]
+> Prefect flows can handle retries, dependencies, and even complex branching logic
+>
+> [Check our docs](https://docs.prefect.io/v3/get-started/index?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) or see the example below to learn more!
+
## Getting started
Prefect requires Python 3.9 or later. To [install the latest or upgrade to the latest version of Prefect](https://docs.prefect.io/get-started/install), run the following command:
@@ -79,8 +106,12 @@ if __name__ == "__main__":
You now have a process running locally that is looking for scheduled deployments!
Additionally you can run your workflow manually from the UI or CLI. You can even run deployments in response to [events](https://docs.prefect.io/latest/automate/?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none).
-> [!NOTE]
-> To explore different infrastructure options for your workflows, check out the [deployment documentation](https://docs.prefect.io/v3/deploy).
+> [!TIP]
+> Where to go next: check out our [documentation](https://docs.prefect.io/v3/get-started/index?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none) to learn more about:
+> - [Deploying flows to production environments](https://docs.prefect.io/v3/deploy?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none)
+> - [Adding error handling and retries](https://docs.prefect.io/v3/develop/write-tasks?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none#retries)
+> - [Integrating with your existing tools](https://docs.prefect.io/integrations/integrations?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none)
+> - [Setting up team collaboration features](https://docs.prefect.io/v3/manage/cloud/manage-users/manage-teams?utm_source=oss&utm_medium=oss&utm_campaign=oss_gh_repo&utm_term=none&utm_content=none#manage-teams)
## Prefect Cloud
diff --git a/docs/v3/api-ref/rest-api/server/schema.json b/docs/v3/api-ref/rest-api/server/schema.json
index 036dc192dbff..3e4acc39393f 100644
--- a/docs/v3/api-ref/rest-api/server/schema.json
+++ b/docs/v3/api-ref/rest-api/server/schema.json
@@ -9355,7 +9355,10 @@
"description": "Successful Response",
"content": {
"application/json": {
- "schema": {}
+ "schema": {
+ "type": "object",
+ "title": "Response Validate Obj Ui Schemas Validate Post"
+ }
}
}
},
@@ -9727,7 +9730,10 @@
"description": "Successful Response",
"content": {
"application/json": {
- "schema": {}
+ "schema": {
+ "type": "string",
+ "title": "Response Hello Hello Get"
+ }
}
}
},
@@ -22594,6 +22600,9 @@
"description": "An ORM representation of task run data."
},
"TaskRunCount": {
+ "additionalProperties": {
+ "type": "integer"
+ },
"type": "object"
},
"TaskRunCreate": {
diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py
index 298386ef168d..0b7ab08d6ae9 100644
--- a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py
+++ b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run.py
@@ -166,12 +166,12 @@
from google.api_core.client_options import ClientOptions
from googleapiclient import discovery
from googleapiclient.discovery import Resource
+from jsonpatch import JsonPatch
from pydantic import Field, field_validator
from prefect.logging.loggers import PrefectLogAdapter
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect.utilities.dockerutils import get_prefect_image_name
-from prefect.utilities.pydantic import JsonPatch
from prefect.workers.base import (
BaseJobConfiguration,
BaseVariables,
diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py
index 0d3b6989378a..348ad84fa35a 100644
--- a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py
+++ b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py
@@ -11,12 +11,12 @@
# noinspection PyProtectedMember
from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError
+from jsonpatch import JsonPatch
from pydantic import Field, PrivateAttr, field_validator
from prefect.logging.loggers import PrefectLogAdapter
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect.utilities.dockerutils import get_prefect_image_name
-from prefect.utilities.pydantic import JsonPatch
from prefect.workers.base import (
BaseJobConfiguration,
BaseVariables,
diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py b/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py
index d11efff4c3d0..4f92c7a195c1 100644
--- a/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py
+++ b/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py
@@ -28,11 +28,11 @@
from uuid import uuid4
import anyio
+from jsonpatch import JsonPatch
from pydantic import Field, field_validator
from slugify import slugify
from prefect.logging.loggers import PrefectLogAdapter
-from prefect.utilities.pydantic import JsonPatch
from prefect.workers.base import (
BaseJobConfiguration,
BaseVariables,
diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
index 0d3f58637d66..2f83394816cc 100644
--- a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
+++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
@@ -121,6 +121,7 @@
import aiohttp
import anyio.abc
import kubernetes_asyncio
+from jsonpatch import JsonPatch
from kubernetes_asyncio import config
from kubernetes_asyncio.client import (
ApiClient,
@@ -148,7 +149,6 @@
from prefect.server.schemas.core import Flow
from prefect.server.schemas.responses import DeploymentResponse
from prefect.utilities.dockerutils import get_prefect_image_name
-from prefect.utilities.pydantic import JsonPatch
from prefect.utilities.templating import find_placeholders
from prefect.utilities.timeout import timeout_async
from prefect.workers.base import (
diff --git a/src/prefect/_internal/schemas/validators.py b/src/prefect/_internal/schemas/validators.py
index 4bbc9c7de15a..879774fb297e 100644
--- a/src/prefect/_internal/schemas/validators.py
+++ b/src/prefect/_internal/schemas/validators.py
@@ -6,6 +6,8 @@
This will be subject to consolidation and refactoring over the next few months.
"""
+from __future__ import annotations
+
import os
import re
import urllib.parse
@@ -627,18 +629,18 @@ def validate_name_present_on_nonanonymous_blocks(values: M) -> M:
@overload
-def validate_command(v: str) -> Path:
+def validate_working_dir(v: str) -> Path:
...
@overload
-def validate_command(v: None) -> None:
+def validate_working_dir(v: None) -> None:
...
-def validate_command(v: Optional[str]) -> Optional[Path]:
+def validate_working_dir(v: Optional[Path | str]) -> Optional[Path]:
"""Make sure that the working directory is formatted for the current platform."""
- if v is not None:
+ if isinstance(v, str):
return relative_path_to_current_platform(v)
return v
diff --git a/src/prefect/artifacts.py b/src/prefect/artifacts.py
index b8903f960c8c..559093abb862 100644
--- a/src/prefect/artifacts.py
+++ b/src/prefect/artifacts.py
@@ -20,7 +20,10 @@
from prefect.utilities.asyncutils import sync_compatible
from prefect.utilities.context import get_task_and_flow_run_ids
-logger = get_logger("artifacts")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("artifacts")
if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
diff --git a/src/prefect/automations.py b/src/prefect/automations.py
index 3201508220bb..6ec5192a4354 100644
--- a/src/prefect/automations.py
+++ b/src/prefect/automations.py
@@ -137,7 +137,7 @@ def create(self: Self) -> Self:
self.id = client.create_automation(automation=automation)
return self
- async def aupdate(self: Self):
+ async def aupdate(self: Self) -> None:
"""
Updates an existing automation.
diff --git a/src/prefect/blocks/abstract.py b/src/prefect/blocks/abstract.py
index ca29dd04b03e..495eaa67f446 100644
--- a/src/prefect/blocks/abstract.py
+++ b/src/prefect/blocks/abstract.py
@@ -15,7 +15,7 @@
Union,
)
-from typing_extensions import Self, TypeAlias
+from typing_extensions import TYPE_CHECKING, Self, TypeAlias
from prefect.blocks.core import Block
from prefect.exceptions import MissingContextError
@@ -26,7 +26,10 @@
if sys.version_info >= (3, 12):
LoggingAdapter = logging.LoggerAdapter[logging.Logger]
else:
- LoggingAdapter = logging.LoggerAdapter
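+    # Mirror the 3.12+ parameterized alias for type checkers only; at runtime
+    # on older Pythons the plain, unsubscripted LoggerAdapter is kept.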
+ if TYPE_CHECKING:
+ LoggingAdapter = logging.LoggerAdapter[logging.Logger]
+ else:
+ LoggingAdapter = logging.LoggerAdapter
LoggerOrAdapter: TypeAlias = Union[Logger, LoggingAdapter]
diff --git a/src/prefect/cli/dashboard.py b/src/prefect/cli/dashboard.py
index a7b2caea7145..47154f94cac8 100644
--- a/src/prefect/cli/dashboard.py
+++ b/src/prefect/cli/dashboard.py
@@ -8,7 +8,7 @@
from prefect.settings import PREFECT_UI_URL
from prefect.utilities.asyncutils import run_sync_in_worker_thread
-dashboard_app = PrefectTyper(
+dashboard_app: PrefectTyper = PrefectTyper(
name="dashboard",
help="Commands for interacting with the Prefect UI.",
)
@@ -16,7 +16,7 @@
@dashboard_app.command()
-async def open():
+async def open() -> None:
"""
Open the Prefect UI in the browser.
"""
diff --git a/src/prefect/cli/dev.py b/src/prefect/cli/dev.py
index 4a05e53bc398..c901d8fe717d 100644
--- a/src/prefect/cli/dev.py
+++ b/src/prefect/cli/dev.py
@@ -35,13 +35,13 @@
Note that many of these commands require extra dependencies (such as npm and MkDocs)
to function properly.
"""
-dev_app = PrefectTyper(
+dev_app: PrefectTyper = PrefectTyper(
name="dev", short_help="Internal Prefect development.", help=DEV_HELP
)
app.add_typer(dev_app)
-def exit_with_error_if_not_editable_install():
+def exit_with_error_if_not_editable_install() -> None:
if (
prefect.__module_path__.parent == "site-packages"
or not (prefect.__development_base_path__ / "setup.py").exists()
diff --git a/src/prefect/cli/events.py b/src/prefect/cli/events.py
index 1b84c8923947..e9e7b423c15c 100644
--- a/src/prefect/cli/events.py
+++ b/src/prefect/cli/events.py
@@ -15,7 +15,7 @@
get_events_subscriber,
)
-events_app = PrefectTyper(name="events", help="Stream events.")
+events_app: PrefectTyper = PrefectTyper(name="events", help="Stream events.")
app.add_typer(events_app, aliases=["event"])
@@ -60,7 +60,7 @@ async def stream(
handle_error(exc)
-async def handle_event(event: Event, format: StreamFormat, output_file: str):
+async def handle_event(event: Event, format: StreamFormat, output_file: str) -> None:
if format == StreamFormat.json:
event_data = orjson.dumps(event.model_dump(), default=str).decode()
elif format == StreamFormat.text:
@@ -74,7 +74,7 @@ async def handle_event(event: Event, format: StreamFormat, output_file: str):
print(event_data)
-def handle_error(exc):
+def handle_error(exc: Exception) -> None:
if isinstance(exc, websockets.exceptions.ConnectionClosedError):
exit_with_error(f"Connection closed, retrying... ({exc})")
elif isinstance(exc, (KeyboardInterrupt, asyncio.exceptions.CancelledError)):
diff --git a/src/prefect/cli/flow.py b/src/prefect/cli/flow.py
index acc37b4a24cd..735e1958b374 100644
--- a/src/prefect/cli/flow.py
+++ b/src/prefect/cli/flow.py
@@ -19,7 +19,7 @@
from prefect.runner import Runner
from prefect.utilities import urls
-flow_app = PrefectTyper(name="flow", help="View and serve flows.")
+flow_app: PrefectTyper = PrefectTyper(name="flow", help="View and serve flows.")
app.add_typer(flow_app, aliases=["flows"])
diff --git a/src/prefect/cli/flow_run.py b/src/prefect/cli/flow_run.py
index 41bb4fd4189f..63d862c081da 100644
--- a/src/prefect/cli/flow_run.py
+++ b/src/prefect/cli/flow_run.py
@@ -28,13 +28,15 @@
from prefect.runner import Runner
from prefect.states import State
-flow_run_app = PrefectTyper(name="flow-run", help="Interact with flow runs.")
+flow_run_app: PrefectTyper = PrefectTyper(
+ name="flow-run", help="Interact with flow runs."
+)
app.add_typer(flow_run_app, aliases=["flow-runs"])
LOGS_DEFAULT_PAGE_SIZE = 200
LOGS_WITH_LIMIT_FLAG_DEFAULT_NUM_LOGS = 20
-logger = get_logger(__name__)
+logger: "logging.Logger" = get_logger(__name__)
@flow_run_app.command()
diff --git a/src/prefect/cli/global_concurrency_limit.py b/src/prefect/cli/global_concurrency_limit.py
index a391fe327476..dff54dd4d6c3 100644
--- a/src/prefect/cli/global_concurrency_limit.py
+++ b/src/prefect/cli/global_concurrency_limit.py
@@ -22,7 +22,7 @@
PrefectHTTPStatusError,
)
-global_concurrency_limit_app = PrefectTyper(
+global_concurrency_limit_app: PrefectTyper = PrefectTyper(
name="global-concurrency-limit",
help="Manage global concurrency limits.",
)
diff --git a/src/prefect/cli/profile.py b/src/prefect/cli/profile.py
index ce74709ba5ca..55dbd720f85a 100644
--- a/src/prefect/cli/profile.py
+++ b/src/prefect/cli/profile.py
@@ -26,7 +26,9 @@
from prefect.settings import ProfilesCollection
from prefect.utilities.collections import AutoEnum
-profile_app = PrefectTyper(name="profile", help="Select and manage Prefect profiles.")
+profile_app: PrefectTyper = PrefectTyper(
+ name="profile", help="Select and manage Prefect profiles."
+)
app.add_typer(profile_app, aliases=["profiles"])
_OLD_MINIMAL_DEFAULT_PROFILE_CONTENT: str = """active = "default"
@@ -263,8 +265,8 @@ def inspect(
def show_profile_changes(
user_profiles: ProfilesCollection, default_profiles: ProfilesCollection
-):
- changes = []
+) -> bool:
+ changes: list[tuple[str, str]] = []
for name in default_profiles.names:
if name not in user_profiles:
@@ -343,7 +345,7 @@ class ConnectionStatus(AutoEnum):
INVALID_API = AutoEnum.auto()
-async def check_server_connection():
+async def check_server_connection() -> ConnectionStatus:
httpx_settings = dict(timeout=3)
try:
# attempt to infer Cloud 2.0 API from the connection URL
diff --git a/src/prefect/cli/root.py b/src/prefect/cli/root.py
index a6aa37ea7afd..1a3b9973e225 100644
--- a/src/prefect/cli/root.py
+++ b/src/prefect/cli/root.py
@@ -26,16 +26,16 @@
PREFECT_TEST_MODE,
)
-app = PrefectTyper(add_completion=True, no_args_is_help=True)
+app: PrefectTyper = PrefectTyper(add_completion=True, no_args_is_help=True)
-def version_callback(value: bool):
+def version_callback(value: bool) -> None:
if value:
print(prefect.__version__)
raise typer.Exit()
-def is_interactive():
+def is_interactive() -> bool:
return app.console.is_interactive
@@ -157,7 +157,7 @@ def get_prefect_integrations() -> Dict[str, str]:
return integrations
-def display(object: Dict[str, Any], nesting: int = 0):
+def display(object: Dict[str, Any], nesting: int = 0) -> None:
"""Recursive display of a dictionary with nesting."""
for key, value in object.items():
key += ":"
diff --git a/src/prefect/cli/server.py b/src/prefect/cli/server.py
index acbf56a52871..8b9996931bf1 100644
--- a/src/prefect/cli/server.py
+++ b/src/prefect/cli/server.py
@@ -15,6 +15,7 @@
import textwrap
from pathlib import Path
from types import ModuleType
+from typing import TYPE_CHECKING
import typer
import uvicorn
@@ -48,23 +49,30 @@
from prefect.settings.context import temporary_settings
from prefect.utilities.asyncutils import run_sync_in_worker_thread
-server_app = PrefectTyper(
+if TYPE_CHECKING:
+ import logging
+
+server_app: PrefectTyper = PrefectTyper(
name="server",
help="Start a Prefect server instance and interact with the database",
)
-database_app = PrefectTyper(name="database", help="Interact with the database.")
-services_app = PrefectTyper(name="services", help="Interact with server loop services.")
+database_app: PrefectTyper = PrefectTyper(
+ name="database", help="Interact with the database."
+)
+services_app: PrefectTyper = PrefectTyper(
+ name="services", help="Interact with server loop services."
+)
server_app.add_typer(database_app)
server_app.add_typer(services_app)
app.add_typer(server_app)
-logger = get_logger(__name__)
+logger: "logging.Logger" = get_logger(__name__)
SERVER_PID_FILE_NAME = "server.pid"
SERVICES_PID_FILE = Path(PREFECT_HOME.value()) / "services.pid"
-def generate_welcome_blurb(base_url: str, ui_enabled: bool):
+def generate_welcome_blurb(base_url: str, ui_enabled: bool) -> str:
blurb = textwrap.dedent(
r"""
___ ___ ___ ___ ___ ___ _____
@@ -109,7 +117,7 @@ def generate_welcome_blurb(base_url: str, ui_enabled: bool):
return blurb
-def prestart_check(base_url: str):
+def prestart_check(base_url: str) -> None:
"""
Check if `PREFECT_API_URL` is set in the current profile. If not, prompt the user to set it.
diff --git a/src/prefect/cli/shell.py b/src/prefect/cli/shell.py
index 6837e5d00950..968ddd2887f7 100644
--- a/src/prefect/cli/shell.py
+++ b/src/prefect/cli/shell.py
@@ -8,7 +8,7 @@
import subprocess
import sys
import threading
-from typing import Any, Dict, List, Optional
+from typing import IO, Any, Callable, Dict, List, Optional
import typer
from typing_extensions import Annotated
@@ -25,13 +25,13 @@
from prefect.settings import PREFECT_UI_URL
from prefect.types.entrypoint import EntrypointType
-shell_app = PrefectTyper(
+shell_app: PrefectTyper = PrefectTyper(
name="shell", help="Serve and watch shell commands as Prefect flows."
)
app.add_typer(shell_app)
-def output_stream(pipe, logger_function):
+def output_stream(pipe: IO[str], logger_function: Callable[[str], None]) -> None:
"""
Read from a pipe line by line and log using the provided logging function.
@@ -44,7 +44,7 @@ def output_stream(pipe, logger_function):
logger_function(line.strip())
-def output_collect(pipe, container):
+def output_collect(pipe: IO[str], container: list[str]) -> None:
"""
Collects output from a subprocess pipe and stores it in a container list.
diff --git a/src/prefect/cli/task.py b/src/prefect/cli/task.py
index bedc2b4f418e..21f059190778 100644
--- a/src/prefect/cli/task.py
+++ b/src/prefect/cli/task.py
@@ -8,7 +8,7 @@
from prefect.task_worker import serve as task_serve
from prefect.utilities.importtools import import_object
-task_app = PrefectTyper(name="task", help="Work with task scheduling.")
+task_app: PrefectTyper = PrefectTyper(name="task", help="Work with task scheduling.")
app.add_typer(task_app, aliases=["task"])
diff --git a/src/prefect/cli/task_run.py b/src/prefect/cli/task_run.py
index e0c35177fdec..b425b9a3c307 100644
--- a/src/prefect/cli/task_run.py
+++ b/src/prefect/cli/task_run.py
@@ -22,7 +22,9 @@
from prefect.client.schemas.sorting import LogSort, TaskRunSort
from prefect.exceptions import ObjectNotFound
-task_run_app = PrefectTyper(name="task-run", help="View and inspect task runs.")
+task_run_app: PrefectTyper = PrefectTyper(
+ name="task-run", help="View and inspect task runs."
+)
app.add_typer(task_run_app, aliases=["task-runs"])
LOGS_DEFAULT_PAGE_SIZE = 200
diff --git a/src/prefect/cli/variable.py b/src/prefect/cli/variable.py
index dfb1a7cac7d0..919ba3556eaf 100644
--- a/src/prefect/cli/variable.py
+++ b/src/prefect/cli/variable.py
@@ -14,7 +14,7 @@
from prefect.client.schemas.actions import VariableCreate, VariableUpdate
from prefect.exceptions import ObjectNotFound
-variable_app = PrefectTyper(name="variable", help="Manage variables.")
+variable_app: PrefectTyper = PrefectTyper(name="variable", help="Manage variables.")
app.add_typer(variable_app)
diff --git a/src/prefect/cli/work_pool.py b/src/prefect/cli/work_pool.py
index 2c9af2a40f52..ff528d959458 100644
--- a/src/prefect/cli/work_pool.py
+++ b/src/prefect/cli/work_pool.py
@@ -32,11 +32,11 @@
get_default_base_job_template_for_infrastructure_type,
)
-work_pool_app = PrefectTyper(name="work-pool", help="Manage work pools.")
+work_pool_app: PrefectTyper = PrefectTyper(name="work-pool", help="Manage work pools.")
app.add_typer(work_pool_app, aliases=["work-pool"])
-def set_work_pool_as_default(name: str):
+def set_work_pool_as_default(name: str) -> None:
profile = update_current_profile({"PREFECT_DEFAULT_WORK_POOL_NAME": name})
app.console.print(
f"Set {name!r} as default work pool for profile {profile.name!r}\n",
diff --git a/src/prefect/cli/work_queue.py b/src/prefect/cli/work_queue.py
index 6cabb8700f3a..069d8296b48a 100644
--- a/src/prefect/cli/work_queue.py
+++ b/src/prefect/cli/work_queue.py
@@ -20,7 +20,7 @@
from prefect.client.schemas.objects import DEFAULT_AGENT_WORK_POOL_NAME
from prefect.exceptions import ObjectAlreadyExists, ObjectNotFound
-work_app = PrefectTyper(name="work-queue", help="Manage work queues.")
+work_app: PrefectTyper = PrefectTyper(name="work-queue", help="Manage work queues.")
app.add_typer(work_app, aliases=["work-queues"])
diff --git a/src/prefect/cli/worker.py b/src/prefect/cli/worker.py
index 523787548eed..34718485ceff 100644
--- a/src/prefect/cli/worker.py
+++ b/src/prefect/cli/worker.py
@@ -28,7 +28,9 @@
)
from prefect.workers.base import BaseWorker
-worker_app = PrefectTyper(name="worker", help="Start and interact with workers.")
+worker_app: PrefectTyper = PrefectTyper(
+ name="worker", help="Start and interact with workers."
+)
app.add_typer(worker_app)
diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py
index 4622a7d6fe32..3aa7043333d3 100644
--- a/src/prefect/client/utilities.py
+++ b/src/prefect/client/utilities.py
@@ -5,7 +5,7 @@
# This module must not import from `prefect.client` when it is imported to avoid
# circular imports for decorators such as `inject_client` which are widely used.
-from collections.abc import Awaitable, Coroutine
+from collections.abc import Coroutine
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -61,8 +61,8 @@ def get_or_create_client(
def client_injector(
- func: Callable[Concatenate["PrefectClient", P], Awaitable[R]],
-) -> Callable[P, Awaitable[R]]:
+ func: Callable[Concatenate["PrefectClient", P], Coroutine[Any, Any, R]],
+) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
client, _ = get_or_create_client()
diff --git a/src/prefect/deployments/base.py b/src/prefect/deployments/base.py
index 9ec16a909fff..5700b8c34848 100644
--- a/src/prefect/deployments/base.py
+++ b/src/prefect/deployments/base.py
@@ -19,6 +19,7 @@
from prefect.client.schemas.objects import ConcurrencyLimitStrategy
from prefect.client.schemas.schedules import IntervalSchedule
from prefect.utilities._git import get_git_branch, get_git_remote_origin_url
+from prefect.utilities.annotations import NotSet
from prefect.utilities.filesystem import create_default_ignore_file
from prefect.utilities.templating import apply_values
@@ -113,7 +114,9 @@ def create_default_prefect_yaml(
return True
-def configure_project_by_recipe(recipe: str, **formatting_kwargs) -> dict:
+def configure_project_by_recipe(
+ recipe: str, **formatting_kwargs: Any
+) -> dict[str, Any] | type[NotSet]:
"""
Given a recipe name, returns a dictionary representing base configuration options.
@@ -131,13 +134,13 @@ def configure_project_by_recipe(recipe: str, **formatting_kwargs) -> dict:
raise ValueError(f"Unknown recipe {recipe!r} provided.")
with recipe_path.open(mode="r") as f:
- config = yaml.safe_load(f)
+ config: dict[str, Any] = yaml.safe_load(f)
- config = apply_values(
+ templated_config = apply_values(
template=config, values=formatting_kwargs, remove_notset=False
)
- return config
+ return templated_config
def initialize_project(
diff --git a/src/prefect/deployments/flow_runs.py b/src/prefect/deployments/flow_runs.py
index e54615a96952..169539df574f 100644
--- a/src/prefect/deployments/flow_runs.py
+++ b/src/prefect/deployments/flow_runs.py
@@ -29,7 +29,11 @@
}
)
-logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
@sync_compatible
diff --git a/src/prefect/deployments/steps/core.py b/src/prefect/deployments/steps/core.py
index ef6118b297a9..4efc15dbcaf6 100644
--- a/src/prefect/deployments/steps/core.py
+++ b/src/prefect/deployments/steps/core.py
@@ -141,7 +141,7 @@ async def run_steps(
steps: List[Dict[str, Any]],
upstream_outputs: Optional[Dict[str, Any]] = None,
print_function: Any = print,
-):
+) -> dict[str, Any]:
upstream_outputs = deepcopy(upstream_outputs) if upstream_outputs else {}
for step in steps:
if not step:
diff --git a/src/prefect/deployments/steps/pull.py b/src/prefect/deployments/steps/pull.py
index 1bc0ec06d64f..54faa6f337ee 100644
--- a/src/prefect/deployments/steps/pull.py
+++ b/src/prefect/deployments/steps/pull.py
@@ -12,7 +12,10 @@
from prefect.runner.storage import BlockStorageAdapter, GitRepository, RemoteStorage
from prefect.utilities.asyncutils import run_coro_as_sync
-deployment_logger = get_logger("deployment")
+if TYPE_CHECKING:
+ import logging
+
+deployment_logger: "logging.Logger" = get_logger("deployment")
if TYPE_CHECKING:
from prefect.blocks.core import Block
@@ -197,7 +200,7 @@ def git_clone(
return dict(directory=str(storage.destination.relative_to(Path.cwd())))
-async def pull_from_remote_storage(url: str, **settings: Any):
+async def pull_from_remote_storage(url: str, **settings: Any) -> dict[str, Any]:
"""
Pulls code from a remote storage location into the current working directory.
@@ -239,7 +242,9 @@ async def pull_from_remote_storage(url: str, **settings: Any):
return {"directory": directory}
-async def pull_with_block(block_document_name: str, block_type_slug: str):
+async def pull_with_block(
+ block_document_name: str, block_type_slug: str
+) -> dict[str, Any]:
"""
Pulls code using a block.
diff --git a/src/prefect/deployments/steps/utility.py b/src/prefect/deployments/steps/utility.py
index 21e2436858d0..53f8529ff387 100644
--- a/src/prefect/deployments/steps/utility.py
+++ b/src/prefect/deployments/steps/utility.py
@@ -26,7 +26,7 @@
import string
import subprocess
import sys
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
from anyio import create_task_group
from anyio.streams.text import TextReceiveStream
@@ -205,7 +205,7 @@ async def pip_install_requirements(
directory: Optional[str] = None,
requirements_file: str = "requirements.txt",
stream_output: bool = True,
-):
+) -> dict[str, Any]:
"""
Installs dependencies from a requirements.txt file.
diff --git a/src/prefect/docker/docker_image.py b/src/prefect/docker/docker_image.py
index c58442cd0e94..a2de9980d52a 100644
--- a/src/prefect/docker/docker_image.py
+++ b/src/prefect/docker/docker_image.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
from pendulum import now as pendulum_now
@@ -34,7 +34,11 @@ class DockerImage:
"""
def __init__(
- self, name: str, tag: Optional[str] = None, dockerfile="auto", **build_kwargs
+ self,
+ name: str,
+ tag: Optional[str] = None,
+ dockerfile: str = "auto",
+ **build_kwargs: Any,
):
image_name, image_tag = parse_image_tag(name)
if tag and image_tag:
@@ -49,16 +53,16 @@ def __init__(
namespace = PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE.value()
# join the namespace and repository to create the full image name
# ignore namespace if it is None
- self.name = "/".join(filter(None, [namespace, repository]))
- self.tag = tag or image_tag or slugify(pendulum_now("utc").isoformat())
- self.dockerfile = dockerfile
- self.build_kwargs = build_kwargs
+ self.name: str = "/".join(filter(None, [namespace, repository]))
+ self.tag: str = tag or image_tag or slugify(pendulum_now("utc").isoformat())
+ self.dockerfile: str = dockerfile
+ self.build_kwargs: dict[str, Any] = build_kwargs
@property
- def reference(self):
+ def reference(self) -> str:
return f"{self.name}:{self.tag}"
- def build(self):
+ def build(self) -> None:
full_image_name = self.reference
build_kwargs = self.build_kwargs.copy()
build_kwargs["context"] = Path.cwd()
@@ -72,7 +76,7 @@ def build(self):
build_kwargs["dockerfile"] = self.dockerfile
build_image(**build_kwargs)
- def push(self):
+ def push(self) -> None:
with docker_client() as client:
events = client.api.push(
repository=self.name, tag=self.tag, stream=True, decode=True
diff --git a/src/prefect/engine.py b/src/prefect/engine.py
index 262fdaa67071..a2829a7a2001 100644
--- a/src/prefect/engine.py
+++ b/src/prefect/engine.py
@@ -1,6 +1,6 @@
import os
import sys
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
from uuid import UUID
from prefect._internal.compatibility.migration import getattr_migration
@@ -15,12 +15,19 @@
run_coro_as_sync,
)
-engine_logger = get_logger("engine")
+if TYPE_CHECKING:
+ import logging
+
+ from prefect.flow_engine import FlowRun
+ from prefect.flows import Flow
+ from prefect.logging.loggers import LoggingAdapter
+
+engine_logger: "logging.Logger" = get_logger("engine")
if __name__ == "__main__":
try:
- flow_run_id = UUID(
+ flow_run_id: UUID = UUID(
sys.argv[1] if len(sys.argv) > 1 else os.environ.get("PREFECT__FLOW_RUN_ID")
)
except Exception:
@@ -37,11 +44,11 @@
run_flow,
)
- flow_run = load_flow_run(flow_run_id=flow_run_id)
- run_logger = flow_run_logger(flow_run=flow_run)
+ flow_run: "FlowRun" = load_flow_run(flow_run_id=flow_run_id)
+ run_logger: "LoggingAdapter" = flow_run_logger(flow_run=flow_run)
try:
- flow = load_flow(flow_run)
+ flow: "Flow[..., Any]" = load_flow(flow_run)
except Exception:
run_logger.error(
"Unexpected exception encountered when trying to load flow",
@@ -55,15 +62,17 @@
else:
run_flow(flow, flow_run=flow_run, error_logger=run_logger)
- except Abort as exc:
+ except Abort as abort_signal:
+ abort_signal: Abort
engine_logger.info(
f"Engine execution of flow run '{flow_run_id}' aborted by orchestrator:"
- f" {exc}"
+ f" {abort_signal}"
)
exit(0)
- except Pause as exc:
+ except Pause as pause_signal:
+ pause_signal: Pause
engine_logger.info(
- f"Engine execution of flow run '{flow_run_id}' is paused: {exc}"
+ f"Engine execution of flow run '{flow_run_id}' is paused: {pause_signal}"
)
exit(0)
except Exception:
diff --git a/src/prefect/events/cli/automations.py b/src/prefect/events/cli/automations.py
index b4ce5dbac804..c70c2ec34c05 100644
--- a/src/prefect/events/cli/automations.py
+++ b/src/prefect/events/cli/automations.py
@@ -3,7 +3,7 @@
"""
import functools
-from typing import Optional, Type
+from typing import Any, Callable, Optional, Type
from uuid import UUID
import orjson
@@ -21,16 +21,16 @@
from prefect.events.schemas.automations import Automation
from prefect.exceptions import PrefectHTTPStatusError
-automations_app = PrefectTyper(
+automations_app: PrefectTyper = PrefectTyper(
name="automation",
help="Manage automations.",
)
app.add_typer(automations_app, aliases=["automations"])
-def requires_automations(func):
+def requires_automations(func: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(func)
- async def wrapper(*args, **kwargs):
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return await func(*args, **kwargs)
except RuntimeError as exc:
diff --git a/src/prefect/events/clients.py b/src/prefect/events/clients.py
index 41b83ee9f41e..3ce243ee9337 100644
--- a/src/prefect/events/clients.py
+++ b/src/prefect/events/clients.py
@@ -70,18 +70,21 @@
labelnames=["client"],
)
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
-def http_to_ws(url: str):
+def http_to_ws(url: str) -> str:
return url.replace("https://", "wss://").replace("http://", "ws://").rstrip("/")
-def events_in_socket_from_api_url(url: str):
+def events_in_socket_from_api_url(url: str) -> str:
return http_to_ws(url) + "/events/in"
-def events_out_socket_from_api_url(url: str):
+def events_out_socket_from_api_url(url: str) -> str:
return http_to_ws(url) + "/events/out"
@@ -250,11 +253,11 @@ class AssertingEventsClient(EventsClient):
last: ClassVar["Optional[AssertingEventsClient]"] = None
all: ClassVar[List["AssertingEventsClient"]] = []
- args: Tuple
- kwargs: Dict[str, Any]
- events: List[Event]
+ args: tuple[Any, ...]
+ kwargs: dict[str, Any]
+ events: list[Event]
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any):
AssertingEventsClient.last = self
AssertingEventsClient.all.append(self)
self.args = args
@@ -431,13 +434,13 @@ class AssertingPassthroughEventsClient(PrefectEventsClient):
during tests AND sends them to a Prefect server."""
last: ClassVar["Optional[AssertingPassthroughEventsClient]"] = None
- all: ClassVar[List["AssertingPassthroughEventsClient"]] = []
+ all: ClassVar[list["AssertingPassthroughEventsClient"]] = []
- args: Tuple
- kwargs: Dict[str, Any]
- events: List[Event]
+ args: tuple[Any, ...]
+ kwargs: dict[str, Any]
+ events: list[Event]
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
AssertingPassthroughEventsClient.last = self
AssertingPassthroughEventsClient.all.append(self)
@@ -449,7 +452,7 @@ def reset(cls) -> None:
cls.last = None
cls.all = []
- def pop_events(self) -> List[Event]:
+ def pop_events(self) -> list[Event]:
events = self.events
self.events = []
return events
diff --git a/src/prefect/events/schemas/automations.py b/src/prefect/events/schemas/automations.py
index 97fe5c7e9a50..2eeea40214c6 100644
--- a/src/prefect/events/schemas/automations.py
+++ b/src/prefect/events/schemas/automations.py
@@ -52,7 +52,7 @@ def describe_for_cli(self, indent: int = 0) -> str:
_deployment_id: Optional[UUID] = PrivateAttr(default=None)
- def set_deployment_id(self, deployment_id: UUID):
+ def set_deployment_id(self, deployment_id: UUID) -> None:
self._deployment_id = deployment_id
def owner_resource(self) -> Optional[str]:
@@ -277,7 +277,7 @@ class MetricTriggerQuery(PrefectBaseModel):
)
@field_validator("range", "firing_for")
- def enforce_minimum_range(cls, value: timedelta):
+ def enforce_minimum_range(cls, value: timedelta) -> timedelta:
if value < timedelta(seconds=300):
raise ValueError("The minimum range is 300 seconds (5 minutes)")
return value
@@ -404,13 +404,17 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg
"""Defines an action a user wants to take when a certain number of events
do or don't happen to the matching resources"""
- name: str = Field(..., description="The name of this automation")
- description: str = Field("", description="A longer description of this automation")
+ name: str = Field(default=..., description="The name of this automation")
+ description: str = Field(
+ default="", description="A longer description of this automation"
+ )
- enabled: bool = Field(True, description="Whether this automation will be evaluated")
+ enabled: bool = Field(
+ default=True, description="Whether this automation will be evaluated"
+ )
trigger: TriggerTypes = Field(
- ...,
+ default=...,
description=(
"The criteria for which events this Automation covers and how it will "
"respond to the presence or absence of those events"
@@ -418,7 +422,7 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg
)
actions: List[ActionTypes] = Field(
- ...,
+ default=...,
description="The actions to perform when this Automation triggers",
)
@@ -438,4 +442,4 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg
class Automation(AutomationCore):
- id: UUID = Field(..., description="The ID of this automation")
+ id: UUID = Field(default=..., description="The ID of this automation")
diff --git a/src/prefect/events/schemas/events.py b/src/prefect/events/schemas/events.py
index eddb69e42742..f7d7e7dba57b 100644
--- a/src/prefect/events/schemas/events.py
+++ b/src/prefect/events/schemas/events.py
@@ -1,6 +1,7 @@
import copy
from collections import defaultdict
from typing import (
+ TYPE_CHECKING,
Any,
ClassVar,
Dict,
@@ -32,7 +33,10 @@
from .labelling import Labelled
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
class Resource(Labelled):
diff --git a/src/prefect/filesystems.py b/src/prefect/filesystems.py
index 20c0a45d23dd..aec69a916bd7 100644
--- a/src/prefect/filesystems.py
+++ b/src/prefect/filesystems.py
@@ -281,7 +281,7 @@ class RemoteFileSystem(WritableFileSystem, WritableDeploymentStorage):
_filesystem: fsspec.AbstractFileSystem = None
@field_validator("basepath")
- def check_basepath(cls, value):
+ def check_basepath(cls, value: str) -> str:
return validate_basepath(value)
def _resolve_path(self, path: str) -> str:
diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py
index cbf50b9d17d3..359e7ff5f995 100644
--- a/src/prefect/flow_engine.py
+++ b/src/prefect/flow_engine.py
@@ -146,7 +146,7 @@ class BaseFlowRunEngine(Generic[P, R]):
_flow_run_name_set: bool = False
_telemetry: RunTelemetry = field(default_factory=RunTelemetry)
- def __post_init__(self):
+ def __post_init__(self) -> None:
if self.flow is None and self.flow_run_id is None:
raise ValueError("Either a flow or a flow_run_id must be provided.")
@@ -167,7 +167,7 @@ def is_pending(self) -> bool:
return False # TODO: handle this differently?
return getattr(self, "flow_run").state.is_pending()
- def cancel_all_tasks(self):
+ def cancel_all_tasks(self) -> None:
if hasattr(self.flow.task_runner, "cancel_all"):
self.flow.task_runner.cancel_all() # type: ignore
@@ -208,6 +208,8 @@ def _update_otel_labels(
@dataclass
class FlowRunEngine(BaseFlowRunEngine[P, R]):
_client: Optional[SyncPrefectClient] = None
+ flow_run: FlowRun | None = None
+ parameters: dict[str, Any] | None = None
@property
def client(self) -> SyncPrefectClient:
@@ -502,7 +504,7 @@ def create_flow_run(self, client: SyncPrefectClient) -> FlowRun:
tags=TagsContext.get().current_tags,
)
- def call_hooks(self, state: Optional[State] = None):
+ def call_hooks(self, state: Optional[State] = None) -> None:
if state is None:
state = self.state
flow = self.flow
@@ -600,7 +602,9 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
# set the logger to the flow run logger
- self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow)
+ self.logger: "logging.Logger" = flow_run_logger(
+ flow_run=self.flow_run, flow=self.flow
+ ) # type: ignore
# update the flow run name if necessary
if not self._flow_run_name_set and self.flow.flow_run_name:
@@ -768,6 +772,8 @@ class AsyncFlowRunEngine(BaseFlowRunEngine[P, R]):
"""
_client: Optional[PrefectClient] = None
+ parameters: dict[str, Any] | None = None
+ flow_run: FlowRun | None = None
@property
def client(self) -> PrefectClient:
@@ -1061,7 +1067,7 @@ async def create_flow_run(self, client: PrefectClient) -> FlowRun:
tags=TagsContext.get().current_tags,
)
- async def call_hooks(self, state: Optional[State] = None):
+ async def call_hooks(self, state: Optional[State] = None) -> None:
if state is None:
state = self.state
flow = self.flow
@@ -1158,7 +1164,9 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None):
stack.enter_context(ConcurrencyContext())
# set the logger to the flow run logger
- self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow)
+ self.logger: "logging.Logger" = flow_run_logger(
+ flow_run=self.flow_run, flow=self.flow
+ )
# update the flow run name if necessary
@@ -1320,7 +1328,7 @@ def run_flow_sync(
flow: Flow[P, R],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
- wait_for: Optional[Iterable[PrefectFuture]] = None,
+ wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
engine = FlowRunEngine[P, R](
@@ -1342,7 +1350,7 @@ async def run_flow_async(
flow: Flow[P, R],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
- wait_for: Optional[Iterable[PrefectFuture]] = None,
+ wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
engine = AsyncFlowRunEngine[P, R](
@@ -1361,7 +1369,7 @@ def run_generator_flow_sync(
flow: Flow[P, R],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
- wait_for: Optional[Iterable[PrefectFuture]] = None,
+ wait_for: Optional[Iterable[PrefectFuture[Any]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Generator[R, None, None]:
if return_type != "result":
diff --git a/src/prefect/flows.py b/src/prefect/flows.py
index af93f1a652d1..9a3af1eff546 100644
--- a/src/prefect/flows.py
+++ b/src/prefect/flows.py
@@ -86,6 +86,7 @@
sync_compatible,
)
from prefect.utilities.callables import (
+ ParameterSchema,
get_call_parameters,
parameter_schema,
parameters_to_args_kwargs,
@@ -272,7 +273,7 @@ def __init__(
if not callable(fn):
raise TypeError("'fn' must be callable")
- self.name = name or fn.__name__.replace("_", "-").replace(
+ self.name: str = name or fn.__name__.replace("_", "-").replace(
"",
"unknown-lambda", # prefect API will not accept "<" or ">" in flow names
)
@@ -287,29 +288,29 @@ def __init__(
self.flow_run_name = flow_run_name
if task_runner is None:
- self.task_runner = cast(
+ self.task_runner: TaskRunner[PrefectFuture[Any]] = cast(
TaskRunner[PrefectFuture[Any]], ThreadPoolTaskRunner()
)
else:
- self.task_runner = (
+ self.task_runner: TaskRunner[PrefectFuture[Any]] = (
task_runner() if isinstance(task_runner, type) else task_runner
)
self.log_prints = log_prints
- self.description = description or inspect.getdoc(fn)
+ self.description: str | None = description or inspect.getdoc(fn)
update_wrapper(self, fn)
self.fn = fn
# the flow is considered async if its function is async or an async
# generator
- self.isasync = asyncio.iscoroutinefunction(
+ self.isasync: bool = asyncio.iscoroutinefunction(
self.fn
) or inspect.isasyncgenfunction(self.fn)
# the flow is considered a generator if its function is a generator or
# an async generator
- self.isgenerator = inspect.isgeneratorfunction(
+ self.isgenerator: bool = inspect.isgeneratorfunction(
self.fn
) or inspect.isasyncgenfunction(self.fn)
@@ -326,22 +327,24 @@ def __init__(
pass # `getsourcefile` can return null values and "" for objects in repls
self.version = version
- self.timeout_seconds = float(timeout_seconds) if timeout_seconds else None
+ self.timeout_seconds: float | None = (
+ float(timeout_seconds) if timeout_seconds else None
+ )
# FlowRunPolicy settings
# TODO: We can instantiate a `FlowRunPolicy` and add Pydantic bound checks to
# validate that the user passes positive numbers here
- self.retries = (
+ self.retries: int = (
retries if retries is not None else PREFECT_FLOW_DEFAULT_RETRIES.value()
)
- self.retry_delay_seconds = (
+ self.retry_delay_seconds: float | int = (
retry_delay_seconds
if retry_delay_seconds is not None
else PREFECT_FLOW_DEFAULT_RETRY_DELAY_SECONDS.value()
)
- self.parameters = parameter_schema(self.fn)
+ self.parameters: ParameterSchema = parameter_schema(self.fn)
self.should_validate_parameters = validate_parameters
if self.should_validate_parameters:
@@ -421,7 +424,7 @@ def with_options(
description: Optional[str] = None,
flow_run_name: Optional[Union[Callable[[], str], str]] = None,
task_runner: Union[
- Type[TaskRunner[PrefectFuture[R]]], TaskRunner[PrefectFuture[R]], None
+ Type[TaskRunner[PrefectFuture[Any]]], TaskRunner[PrefectFuture[Any]], None
] = None,
timeout_seconds: Union[int, float, None] = None,
validate_parameters: Optional[bool] = None,
@@ -1708,7 +1711,7 @@ def from_source(
...
-flow = FlowDecorator()
+flow: FlowDecorator = FlowDecorator()
def _raise_on_name_with_banned_characters(name: Optional[str]) -> Optional[str]:
diff --git a/src/prefect/futures.py b/src/prefect/futures.py
index 97e5bba64cc0..ee31acc58503 100644
--- a/src/prefect/futures.py
+++ b/src/prefect/futures.py
@@ -5,7 +5,7 @@
import uuid
from collections.abc import Generator, Iterator
from functools import partial
-from typing import Any, Callable, Generic, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union
from typing_extensions import NamedTuple, Self, TypeVar
@@ -22,7 +22,10 @@
F = TypeVar("F")
R = TypeVar("R")
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
class PrefectFuture(abc.ABC, Generic[R]):
diff --git a/src/prefect/infrastructure/provisioners/__init__.py b/src/prefect/infrastructure/provisioners/__init__.py
index 545a8576a102..bb8270360841 100644
--- a/src/prefect/infrastructure/provisioners/__init__.py
+++ b/src/prefect/infrastructure/provisioners/__init__.py
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Type
+from prefect.infrastructure.provisioners.coiled import CoiledPushProvisioner
from prefect.infrastructure.provisioners.modal import ModalPushProvisioner
from .cloud_run import CloudRunPushProvisioner
from .container_instance import ContainerInstancePushProvisioner
@@ -15,6 +16,7 @@
"azure-container-instance:push": ContainerInstancePushProvisioner,
"ecs:push": ElasticContainerServicePushProvisioner,
"modal:push": ModalPushProvisioner,
+ "coiled:push": CoiledPushProvisioner,
}
diff --git a/src/prefect/infrastructure/provisioners/cloud_run.py b/src/prefect/infrastructure/provisioners/cloud_run.py
index 4b6d281c688d..011d5e2f4066 100644
--- a/src/prefect/infrastructure/provisioners/cloud_run.py
+++ b/src/prefect/infrastructure/provisioners/cloud_run.py
@@ -33,7 +33,7 @@
class CloudRunPushProvisioner:
def __init__(self):
- self._console = Console()
+ self._console: Console = Console()
self._project = None
self._region = None
self._service_account_name = "prefect-cloud-run"
@@ -41,14 +41,14 @@ def __init__(self):
self._image_repository_name = "prefect-images"
@property
- def console(self):
+ def console(self) -> Console:
return self._console
@console.setter
- def console(self, value):
+ def console(self, value: Console) -> None:
self._console = value
- async def _run_command(self, command: str, *args, **kwargs):
+ async def _run_command(self, command: str, *args: Any, **kwargs: Any) -> str:
result = await run_process(shlex.split(command), check=False, *args, **kwargs)
if result.returncode != 0:
diff --git a/src/prefect/infrastructure/provisioners/coiled.py b/src/prefect/infrastructure/provisioners/coiled.py
new file mode 100644
index 000000000000..f0adb5d7ce50
--- /dev/null
+++ b/src/prefect/infrastructure/provisioners/coiled.py
@@ -0,0 +1,249 @@
+import importlib
+import shlex
+import sys
+from copy import deepcopy
+from types import ModuleType
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from anyio import run_process
+from rich.console import Console
+from rich.progress import Progress, SpinnerColumn, TextColumn
+from rich.prompt import Confirm
+
+from prefect.client.schemas.actions import BlockDocumentCreate
+from prefect.client.schemas.objects import BlockDocument
+from prefect.client.utilities import inject_client
+from prefect.exceptions import ObjectNotFound
+from prefect.utilities.importtools import lazy_import
+
+if TYPE_CHECKING:
+ from prefect.client.orchestration import PrefectClient
+
+
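+# `lazy_import` defers the actual import of coiled until first attribute
+# access, so coiled stays an optional dependency.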
+coiled: ModuleType = lazy_import("coiled")
+
+
+class CoiledPushProvisioner:
+ """
+    An infrastructure provisioner for Coiled push work pools.
+ """
+
+ def __init__(self, client: Optional["PrefectClient"] = None):
+ self._console = Console()
+
+ @property
+ def console(self) -> Console:
+ return self._console
+
+ @console.setter
+ def console(self, value: Console) -> None:
+ self._console = value
+
+ @staticmethod
+ def _is_coiled_installed() -> bool:
+ """
+ Checks if the coiled package is installed.
+
+ Returns:
+ True if the coiled package is installed, False otherwise
+ """
+ try:
+ importlib.import_module("coiled")
+ return True
+ except ModuleNotFoundError:
+ return False
+
+ async def _install_coiled(self):
+ """
+ Installs the coiled package.
+ """
+ with Progress(
+ SpinnerColumn(),
+ TextColumn("[bold blue]Installing coiled..."),
+ transient=True,
+ console=self.console,
+ ) as progress:
+ task = progress.add_task("coiled install")
+ progress.start()
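+            # Re-bind the module-level lazy `coiled` placeholder to the real
+            # module once pip has installed it.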
+ global coiled
+ await run_process(
+ [shlex.quote(sys.executable), "-m", "pip", "install", "coiled"]
+ )
+ coiled = importlib.import_module("coiled")
+ progress.advance(task)
+
+ async def _get_coiled_token(self) -> str:
+ """
+ Gets a Coiled API token from the current Coiled configuration.
+ """
+ import dask.config
+
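+        # Coiled exposes its API token via the dask configuration system;
+        # default to an empty string when no token is configured.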
+ return dask.config.get("coiled.token", "")
+
+ async def _create_new_coiled_token(self):
+ """
+        Triggers a Coiled login via the browser if there is no current token, creating a new one.
+ """
+ await run_process(["coiled", "login"])
+
+ async def _create_coiled_credentials_block(
+ self,
+ block_document_name: str,
+ coiled_token: str,
+ client: "PrefectClient",
+ ) -> BlockDocument:
+ """
+ Creates a CoiledCredentials block containing the provided token.
+
+ Args:
+ block_document_name: The name of the block document to create
+ coiled_token: The Coiled API token
+
+ Returns:
+            The created block document
+ """
+ assert client is not None, "client injection failed"
+ try:
+ credentials_block_type = await client.read_block_type_by_slug(
+ "coiled-credentials"
+ )
+ except ObjectNotFound:
+ # Shouldn't happen, but just in case
+ raise RuntimeError(
+ "Unable to find CoiledCredentials block type. Please ensure you are"
+ " using Prefect Cloud."
+ )
+ credentials_block_schema = (
+ await client.get_most_recent_block_schema_for_block_type(
+ block_type_id=credentials_block_type.id
+ )
+ )
+ assert (
+ credentials_block_schema is not None
+ ), f"Unable to find schema for block type {credentials_block_type.slug}"
+
+ block_doc = await client.create_block_document(
+ block_document=BlockDocumentCreate(
+ name=block_document_name,
+ data={
+ "api_token": coiled_token,
+ },
+ block_type_id=credentials_block_type.id,
+ block_schema_id=credentials_block_schema.id,
+ )
+ )
+ return block_doc
+
+ @inject_client
+ async def provision(
+ self,
+ work_pool_name: str,
+ base_job_template: Dict[str, Any],
+ client: Optional["PrefectClient"] = None,
+ ) -> Dict[str, Any]:
+ """
+ Provisions resources necessary for a Coiled push work pool.
+
+ Provisioned resources:
+ - A CoiledCredentials block containing a Coiled API token
+
+ Args:
+ work_pool_name: The name of the work pool to provision resources for
+ base_job_template: The base job template to update
+
+ Returns:
+ A copy of the provided base job template with the provisioned resources
+ """
+ credentials_block_name = f"{work_pool_name}-coiled-credentials"
+ base_job_template_copy = deepcopy(base_job_template)
+ assert client is not None, "client injection failed"
+ try:
+ block_doc = await client.read_block_document_by_name(
+ credentials_block_name, "coiled-credentials"
+ )
+ self.console.print(
+ f"Work pool [blue]{work_pool_name!r}[/] will reuse the existing Coiled"
+ f" credentials block [blue]{credentials_block_name!r}[/blue]"
+ )
+ except ObjectNotFound:
+ if self._console.is_interactive and not Confirm.ask(
+ (
+ "\n"
+ "To configure your Coiled push work pool we'll need to store a Coiled"
+ " API token with Prefect Cloud as a block. We'll pull the token from"
+ " your local Coiled configuration or create a new token if we"
+ " can't find one.\n"
+ "\n"
+ "Would you like to continue?"
+ ),
+ console=self.console,
+ default=True,
+ ):
+ self.console.print(
+ "No problem! You can always configure your Coiled push work pool"
+ " later via the Prefect UI."
+ )
+ return base_job_template
+
+ if not self._is_coiled_installed():
+ if self.console.is_interactive and Confirm.ask(
+ (
+ "The [blue]coiled[/] package is required to configure"
+ " authentication for your work pool.\n"
+ "\n"
+ "Would you like to install it now?"
+ ),
+ console=self.console,
+ default=True,
+ ):
+ await self._install_coiled()
+
+ if not self._is_coiled_installed():
+ raise RuntimeError(
+ "The coiled package is not installed.\n\nPlease try installing coiled,"
+ " or you can use the Prefect UI to create your Coiled push work pool."
+ )
+
+ # Get the current Coiled API token
+ coiled_api_token = await self._get_coiled_token()
+ if not coiled_api_token:
+            # Create a new token if one wasn't found
+ if self.console.is_interactive and Confirm.ask(
+ "Coiled credentials not found. Would you like to create a new token?",
+ console=self.console,
+ default=True,
+ ):
+ await self._create_new_coiled_token()
+ coiled_api_token = await self._get_coiled_token()
+ else:
+ raise RuntimeError(
+ "Coiled credentials not found. Please create a new token by"
+ " running [blue]coiled login[/] and try again."
+ )
+
+ # Create the credentials block
+ with Progress(
+ SpinnerColumn(),
+ TextColumn("[bold blue]Saving Coiled credentials..."),
+ transient=True,
+ console=self.console,
+ ) as progress:
+ task = progress.add_task("create coiled credentials block")
+ progress.start()
+ block_doc = await self._create_coiled_credentials_block(
+ credentials_block_name,
+ coiled_api_token,
+ client=client,
+ )
+ progress.advance(task)
+
+ base_job_template_copy["variables"]["properties"]["credentials"]["default"] = {
+ "$ref": {"block_document_id": str(block_doc.id)}
+ }
+ if "image" in base_job_template_copy["variables"]["properties"]:
+ base_job_template_copy["variables"]["properties"]["image"]["default"] = ""
+ self.console.print(
+ f"Successfully configured Coiled push work pool {work_pool_name!r}!",
+ style="green",
+ )
+ return base_job_template_copy
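
The Coiled provisioner above follows the same reuse-or-create shape as the other provisioners in this PR. A minimal sketch of that shape, assuming a connected `PrefectClient` and using only client calls that appear in this diff (the helper name is illustrative, not part of the PR):

```python
# Sketch: reuse an existing credentials block document, or create one.
# `client` is assumed to be a connected PrefectClient; names are illustrative.
from prefect.client.orchestration import PrefectClient
from prefect.client.schemas.actions import BlockDocumentCreate
from prefect.exceptions import ObjectNotFound


async def get_or_create_credentials_block(
    client: PrefectClient, name: str, token: str
):
    try:
        # Reuse a block document provisioned on a previous run.
        return await client.read_block_document_by_name(name, "coiled-credentials")
    except ObjectNotFound:
        block_type = await client.read_block_type_by_slug("coiled-credentials")
        schema = await client.get_most_recent_block_schema_for_block_type(
            block_type_id=block_type.id
        )
        assert schema is not None, "no schema registered for block type"
        return await client.create_block_document(
            block_document=BlockDocumentCreate(
                name=name,
                data={"api_token": token},
                block_type_id=block_type.id,
                block_schema_id=schema.id,
            )
        )
```
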
diff --git a/src/prefect/infrastructure/provisioners/container_instance.py b/src/prefect/infrastructure/provisioners/container_instance.py
index 17eb85fdb222..4bbace63fa27 100644
--- a/src/prefect/infrastructure/provisioners/container_instance.py
+++ b/src/prefect/infrastructure/provisioners/container_instance.py
@@ -10,6 +10,7 @@
ContainerInstancePushProvisioner: A class for provisioning infrastructure using Azure Container Instances.
"""
+from __future__ import annotations
import json
import random
@@ -64,7 +65,7 @@ async def run_command(
failure_message: Optional[str] = None,
ignore_if_exists: bool = False,
return_json: bool = False,
- ):
+ ) -> str | dict[str, Any] | None:
"""
Runs an Azure CLI command and processes the output.
@@ -156,7 +157,7 @@ def __init__(self):
self._subscription_name = None
self._location = "eastus"
self._identity_name = "prefect-acr-identity"
- self.azure_cli = AzureCLI(self.console)
+ self.azure_cli: AzureCLI = AzureCLI(self.console)
self._credentials_block_name = None
self._resource_group_name = "prefect-aci-push-pool-rg"
self._app_registration_name = "prefect-aci-push-pool-app"
@@ -170,7 +171,7 @@ def console(self) -> Console:
def console(self, value: Console) -> None:
self._console = value
- async def set_location(self):
+ async def set_location(self) -> None:
"""
Set the Azure resource deployment location to the default or 'eastus' on failure.
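
The new `str | dict[str, Any] | None` return annotation on `run_command` only imports cleanly on older interpreters because the module now defers annotation evaluation. A standalone sketch of why the `__future__` import matters (the function body here is an illustrative stand-in, not the Azure CLI wrapper):

```python
# PEP 604 unions in annotations work on Python < 3.10 once annotations are
# deferred; without the __future__ import this module would raise TypeError
# at import time on 3.9 when `str | dict[str, Any] | None` is evaluated.
from __future__ import annotations

import json
from typing import Any


def run_command(output: str, return_json: bool = False) -> str | dict[str, Any] | None:
    # Illustrative post-processing: empty output -> None, else str or parsed JSON.
    if not output:
        return None
    return json.loads(output) if return_json else output


print(run_command('{"ok": true}', return_json=True))  # {'ok': True}
```
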
diff --git a/src/prefect/infrastructure/provisioners/ecs.py b/src/prefect/infrastructure/provisioners/ecs.py
index a851055e11d3..2f6cb1460f87 100644
--- a/src/prefect/infrastructure/provisioners/ecs.py
+++ b/src/prefect/infrastructure/provisioners/ecs.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import contextlib
import contextvars
@@ -9,9 +11,11 @@
from copy import deepcopy
from functools import partial
from textwrap import dedent
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from types import ModuleType
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional
import anyio
+import anyio.to_thread
from anyio import run_process
from rich.console import Console
from rich.panel import Panel
@@ -33,13 +37,15 @@
if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
-boto3 = lazy_import("boto3")
+boto3: ModuleType = lazy_import("boto3")
-current_console = contextvars.ContextVar("console", default=Console())
+current_console: contextvars.ContextVar[Console] = contextvars.ContextVar(
+ "console", default=Console()
+)
@contextlib.contextmanager
-def console_context(value: Console):
+def console_context(value: Console) -> Generator[None, None, None]:
token = current_console.set(value)
try:
yield
@@ -73,7 +79,7 @@ async def get_task_count(self) -> int:
"""
return 1 if await self.requires_provisioning() else 0
- def _get_policy_by_name(self, name):
+ def _get_policy_by_name(self, name: str) -> dict[str, Any] | None:
paginator = self._iam_client.get_paginator("list_policies")
page_iterator = paginator.paginate(Scope="Local")
@@ -119,9 +125,9 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
- policy_document: Dict[str, Any],
+ policy_document: dict[str, Any],
advance: Callable[[], None],
- ):
+ ) -> str:
"""
Provisions an IAM policy.
@@ -153,7 +159,7 @@ async def provision(
return policy["Arn"]
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return []
@@ -215,7 +221,7 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
advance: Callable[[], None],
- ):
+ ) -> None:
"""
Provisions an IAM user.
@@ -231,7 +237,7 @@ async def provision(
advance()
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return []
@@ -241,7 +247,7 @@ def __init__(self, user_name: str, block_document_name: str):
self._user_name = user_name
self._requires_provisioning = None
- async def get_task_count(self):
+ async def get_task_count(self) -> int:
"""
Returns the number of tasks that will be executed to provision this resource.
@@ -357,7 +363,7 @@ async def provision(
}
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return []
@@ -374,7 +380,7 @@ def __init__(
credentials_block_name or f"{work_pool_name}-aws-credentials"
)
self._policy_name = policy_name
- self._policy_document = {
+ self._policy_document: dict[str, Any] = {
"Version": "2012-10-17",
"Statement": [
{
@@ -417,7 +423,11 @@ def __init__(
self._execution_role_resource = ExecutionRoleResource()
@property
- def resources(self):
+ def resources(
+ self,
+ ) -> list[
+ "ExecutionRoleResource | IamUserResource | IamPolicyResource | CredentialsBlockResource"
+ ]:
return [
self._execution_role_resource,
self._iam_user_resource,
@@ -425,7 +435,7 @@ def resources(self):
self._credentials_block_resource,
]
- async def get_task_count(self):
+ async def get_task_count(self) -> int:
"""
Returns the number of tasks that will be executed to provision this resource.
@@ -461,9 +471,9 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
- base_job_template: Dict[str, Any],
+ base_job_template: dict[str, Any],
advance: Callable[[], None],
- ):
+ ) -> None:
"""
Provisions the authentication resources.
@@ -507,7 +517,7 @@ async def provision(
)
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return [
next_step
for resource in self.resources
@@ -521,7 +531,7 @@ def __init__(self, cluster_name: str = "prefect-ecs-cluster"):
self._cluster_name = cluster_name
self._requires_provisioning = None
- async def get_task_count(self):
+ async def get_task_count(self) -> int:
"""
Returns the number of tasks that will be executed to provision this resource.
@@ -566,9 +576,9 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
- base_job_template: Dict[str, Any],
+ base_job_template: dict[str, Any],
advance: Callable[[], None],
- ):
+ ) -> None:
"""
Provisions an ECS cluster.
@@ -592,7 +602,7 @@ async def provision(
] = self._cluster_name
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return []
@@ -608,7 +618,7 @@ def __init__(
self._requires_provisioning = None
self._ecs_security_group_name = ecs_security_group_name
- async def get_task_count(self):
+ async def get_task_count(self) -> int:
"""
Returns the number of tasks that will be executed to provision this resource.
@@ -642,7 +652,9 @@ async def _get_existing_vpc_cidrs(self):
response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
return [vpc["CidrBlock"] for vpc in response["Vpcs"]]
- async def _find_non_overlapping_cidr(self, default_cidr="172.31.0.0/16"):
+ async def _find_non_overlapping_cidr(
+ self, default_cidr: str = "172.31.0.0/16"
+ ) -> str:
"""Find a non-overlapping CIDR block"""
response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
existing_cidrs = [vpc["CidrBlock"] for vpc in response["Vpcs"]]
@@ -708,9 +720,9 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
- base_job_template: Dict[str, Any],
+ base_job_template: dict[str, Any],
advance: Callable[[], None],
- ):
+ ) -> None:
"""
Provisions a VPC.
@@ -768,7 +780,7 @@ async def provision(
)
)["AvailabilityZones"]
zones = [az["ZoneName"] for az in azs]
- subnets = []
+ subnets: list[Any] = []
for i, subnet_cidr in enumerate(subnet_cidrs[0:3]):
subnets.append(
await anyio.to_thread.run_sync(
@@ -828,7 +840,7 @@ async def provision(
)
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return []
@@ -838,9 +850,9 @@ def __init__(self, work_pool_name: str, repository_name: str = "prefect-flows"):
self._repository_name = repository_name
self._requires_provisioning = None
self._work_pool_name = work_pool_name
- self._next_steps = []
+ self._next_steps: list[str | Panel] = []
- async def get_task_count(self):
+ async def get_task_count(self) -> int:
"""
Returns the number of tasks that will be executed to provision this resource.
@@ -895,9 +907,9 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
- base_job_template: Dict[str, Any],
+ base_job_template: dict[str, Any],
advance: Callable[[], None],
- ):
+ ) -> None:
"""
Provisions an ECR repository.
@@ -978,7 +990,7 @@ def my_flow(name: str = "world"):
)
@property
- def next_steps(self):
+ def next_steps(self) -> list[str | Panel]:
return self._next_steps
@@ -1000,7 +1012,7 @@ def __init__(self, execution_role_name: str = "PrefectEcsTaskExecutionRole"):
)
self._requires_provisioning = None
- async def get_task_count(self):
+ async def get_task_count(self) -> int:
"""
Returns the number of tasks that will be executed to provision this resource.
@@ -1046,9 +1058,9 @@ async def get_planned_actions(self) -> List[str]:
async def provision(
self,
- base_job_template: Dict[str, Any],
+ base_job_template: dict[str, Any],
advance: Callable[[], None],
- ):
+ ) -> str:
"""
Provisions an IAM role.
@@ -1087,7 +1099,7 @@ async def provision(
return response["Role"]["Arn"]
@property
- def next_steps(self):
+ def next_steps(self) -> list[str]:
return []
@@ -1100,11 +1112,11 @@ def __init__(self):
self._console = Console()
@property
- def console(self):
+ def console(self) -> Console:
return self._console
@console.setter
- def console(self, value):
+ def console(self, value: Console) -> None:
self._console = value
async def _prompt_boto3_installation(self):
@@ -1115,7 +1127,7 @@ async def _prompt_boto3_installation(self):
boto3 = importlib.import_module("boto3")
@staticmethod
- def is_boto3_installed():
+ def is_boto3_installed() -> bool:
"""
Check if boto3 is installed.
"""
@@ -1157,8 +1169,8 @@ def _generate_resources(
async def provision(
self,
work_pool_name: str,
- base_job_template: dict,
- ) -> Dict[str, Any]:
+ base_job_template: dict[str, Any],
+ ) -> dict[str, Any]:
"""
Provisions the infrastructure for an ECS push work pool.
@@ -1310,7 +1322,7 @@ async def provision(
# provision calls will be no-ops, but update the base job template
base_job_template_copy = deepcopy(base_job_template)
- next_steps = []
+ next_steps: list[str | Panel] = []
with Progress(console=self._console, disable=num_tasks == 0) as progress:
task = progress.add_task(
"Provisioning Infrastructure",
diff --git a/src/prefect/infrastructure/provisioners/modal.py b/src/prefect/infrastructure/provisioners/modal.py
index 274775817f86..04960d9c5d89 100644
--- a/src/prefect/infrastructure/provisioners/modal.py
+++ b/src/prefect/infrastructure/provisioners/modal.py
@@ -2,6 +2,7 @@
import shlex
import sys
from copy import deepcopy
+from types import ModuleType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from anyio import run_process
@@ -19,7 +20,7 @@
from prefect.client.orchestration import PrefectClient
-modal = lazy_import("modal")
+modal: ModuleType = lazy_import("modal")
class ModalPushProvisioner:
@@ -28,14 +29,14 @@ class ModalPushProvisioner:
"""
def __init__(self, client: Optional["PrefectClient"] = None):
- self._console = Console()
+ self._console: Console = Console()
@property
- def console(self):
+ def console(self) -> Console:
return self._console
@console.setter
- def console(self, value):
+ def console(self, value: Console) -> None:
self._console = value
@staticmethod
diff --git a/src/prefect/input/actions.py b/src/prefect/input/actions.py
index 9e88e5c59b69..e771ca491c39 100644
--- a/src/prefect/input/actions.py
+++ b/src/prefect/input/actions.py
@@ -1,3 +1,4 @@
+import inspect
from typing import TYPE_CHECKING, Any, Optional, Set
from uuid import UUID
@@ -44,9 +45,12 @@ async def create_flow_run_input_from_model(
else:
json_safe = orjson.loads(model_instance.json())
- await create_flow_run_input(
+ coro = create_flow_run_input(
key=key, value=json_safe, flow_run_id=flow_run_id, sender=sender
)
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(coro)
+ await coro
@sync_compatible
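
The recurring `coro = ...` / `assert inspect.iscoroutine(coro)` dance exists because `@sync_compatible` functions advertise a union of a plain value and a coroutine; the assert is only seen by the type checker and narrows the union so the `await` type-checks, at zero runtime cost. A reduced sketch, with `create_flow_run_input` replaced by a simplified stand-in:

```python
# Reduced sketch of the TYPE_CHECKING narrowing pattern. `create_flow_run_input`
# here is a stand-in whose static return type is a union, mimicking what
# prefect's @sync_compatible decorator advertises.
import asyncio
import inspect
from typing import TYPE_CHECKING, Any, Coroutine, Union


async def _create_input(key: str, value: Any) -> None:
    await asyncio.sleep(0)  # placeholder for the real API call


def create_flow_run_input(
    key: str, value: Any
) -> Union[None, Coroutine[Any, Any, None]]:
    # Inside an event loop this returns a coroutine to be awaited.
    return _create_input(key, value)


async def save(key: str, value: Any) -> None:
    coro = create_flow_run_input(key, value)
    if TYPE_CHECKING:
        # Purely static: inspect.iscoroutine is a TypeGuard, so pyright
        # narrows `coro` to the coroutine member of the union.
        assert inspect.iscoroutine(coro)
    await coro


asyncio.run(save("answer", 42))
```
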
diff --git a/src/prefect/input/run_input.py b/src/prefect/input/run_input.py
index 75434eb4cec4..1567b0d2ff72 100644
--- a/src/prefect/input/run_input.py
+++ b/src/prefect/input/run_input.py
@@ -60,11 +60,15 @@ async def receiver_flow():
```
"""
+from __future__ import annotations
+
+import inspect
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
+ Coroutine,
Dict,
Generic,
Literal,
@@ -81,6 +85,7 @@ async def receiver_flow():
import anyio
import pydantic
from pydantic import ConfigDict
+from typing_extensions import Self
from prefect.input.actions import (
create_flow_run_input,
@@ -144,7 +149,7 @@ class RunInputMetadata(pydantic.BaseModel):
receiver: UUID
-class RunInput(pydantic.BaseModel):
+class BaseRunInput(pydantic.BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
_description: Optional[str] = pydantic.PrivateAttr(default=None)
@@ -172,23 +177,29 @@ async def save(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
if is_v2_model(cls):
schema = create_v2_schema(cls.__name__, model_base=cls)
else:
- schema = cls.schema(by_alias=True)
+ schema = cls.model_json_schema(by_alias=True)
- await create_flow_run_input(
+ coro = create_flow_run_input(
key=keyset["schema"], value=schema, flow_run_id=flow_run_id
)
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(coro)
+ await coro
description = cls._description if isinstance(cls._description, str) else None
if description:
- await create_flow_run_input(
+ coro = create_flow_run_input(
key=keyset["description"],
value=description,
flow_run_id=flow_run_id,
)
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(coro)
+ await coro
@classmethod
@sync_compatible
- async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
+ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> Self:
"""
Load the run input response from the given key.
@@ -208,7 +219,7 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
return instance
@classmethod
- def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput"):
+ def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput") -> Self:
"""
Load the run input from a FlowRunInput object.
@@ -284,6 +295,8 @@ async def send_to(
key_prefix=key_prefix,
)
+
+class RunInput(BaseRunInput):
@classmethod
def receive(
cls,
@@ -293,7 +306,7 @@ def receive(
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
- ):
+ ) -> GetInputHandler[Self]:
if key_prefix is None:
key_prefix = f"{cls.__name__.lower()}-auto"
@@ -322,12 +335,12 @@ def subclass_from_base_model_type(
return type(f"{model_cls.__name__}RunInput", (RunInput, model_cls), {}) # type: ignore
-class AutomaticRunInput(RunInput, Generic[T]):
+class AutomaticRunInput(BaseRunInput, Generic[T]):
value: T
@classmethod
@sync_compatible
- async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T:
+ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> Self:
"""
Load the run input response from the given key.
@@ -335,7 +348,10 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T:
- keyset (Keyset): the keyset to load the input for
- flow_run_id (UUID, optional): the flow run ID to load the input for
"""
- instance = await super().load(keyset, flow_run_id=flow_run_id)
+ instance_coro = super().load(keyset, flow_run_id=flow_run_id)
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(instance_coro)
+ instance = await instance_coro
return instance.value
@classmethod
@@ -370,17 +386,34 @@ def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput[T]"]:
# Creating a new Pydantic model class dynamically with the name based
# on the type prefix.
- new_cls: Type["AutomaticRunInput"] = pydantic.create_model(
+ new_cls: Type["AutomaticRunInput[T]"] = pydantic.create_model(
class_name, **fields, __base__=AutomaticRunInput
)
return new_cls
@classmethod
- def receive(cls, *args, **kwargs):
- if kwargs.get("key_prefix") is None:
- kwargs["key_prefix"] = f"{cls.__name__.lower()}-auto"
+ def receive(
+ cls,
+ timeout: Optional[float] = 3600,
+ poll_interval: float = 10,
+ raise_timeout_error: bool = False,
+ exclude_keys: Optional[Set[str]] = None,
+ key_prefix: Optional[str] = None,
+ flow_run_id: Optional[UUID] = None,
+ with_metadata: bool = False,
+ ) -> GetAutomaticInputHandler[T]:
+ key_prefix = key_prefix or f"{cls.__name__.lower()}-auto"
- return GetAutomaticInputHandler(run_input_cls=cls, *args, **kwargs)
+ return GetAutomaticInputHandler(
+ run_input_cls=cls,
+ key_prefix=key_prefix,
+ timeout=timeout,
+ poll_interval=poll_interval,
+ raise_timeout_error=raise_timeout_error,
+ exclude_keys=exclude_keys,
+ flow_run_id=flow_run_id,
+ with_metadata=with_metadata,
+ )
def run_input_subclass_from_type(
@@ -409,24 +442,24 @@ def __init__(
self,
run_input_cls: Type[R],
key_prefix: str,
- timeout: Optional[float] = 3600,
+ timeout: float | None = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
flow_run_id: Optional[UUID] = None,
):
- self.run_input_cls = run_input_cls
- self.key_prefix = key_prefix
- self.timeout = timeout
- self.poll_interval = poll_interval
- self.exclude_keys = set()
- self.raise_timeout_error = raise_timeout_error
- self.flow_run_id = ensure_flow_run_id(flow_run_id)
+ self.run_input_cls: Type[R] = run_input_cls
+ self.key_prefix: str = key_prefix
+ self.timeout: float | None = timeout
+ self.poll_interval: float = poll_interval
+ self.exclude_keys: set[str] = set()
+ self.raise_timeout_error: bool = raise_timeout_error
+ self.flow_run_id: UUID = ensure_flow_run_id(flow_run_id)
if exclude_keys is not None:
self.exclude_keys.update(exclude_keys)
- def __iter__(self):
+ def __iter__(self) -> Self:
return self
def __next__(self) -> R:
@@ -437,24 +470,31 @@ def __next__(self) -> R:
raise
raise StopIteration
- def __aiter__(self):
+ def __aiter__(self) -> Self:
return self
async def __anext__(self) -> R:
try:
- return await self.next()
+ coro = self.next()
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(coro)
+ return await coro
except TimeoutError:
if self.raise_timeout_error:
raise
raise StopAsyncIteration
- async def filter_for_inputs(self):
- flow_run_inputs = await filter_flow_run_input(
+ async def filter_for_inputs(self) -> list["FlowRunInput"]:
+ flow_run_inputs_coro = filter_flow_run_input(
key_prefix=self.key_prefix,
limit=1,
exclude_keys=self.exclude_keys,
flow_run_id=self.flow_run_id,
)
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(flow_run_inputs_coro)
+
+ flow_run_inputs = await flow_run_inputs_coro
if flow_run_inputs:
self.exclude_keys.add(*[i.key for i in flow_run_inputs])
@@ -478,22 +518,91 @@ async def next(self) -> R:
return self.to_instance(flow_run_inputs[0])
-class GetAutomaticInputHandler(GetInputHandler, Generic[T]):
- def __init__(self, *args, **kwargs):
- self.with_metadata = kwargs.pop("with_metadata", False)
- super().__init__(*args, **kwargs)
+class GetAutomaticInputHandler(Generic[T]):
+ def __init__(
+ self,
+ run_input_cls: Type[AutomaticRunInput[T]],
+ key_prefix: str,
+ timeout: float | None = 3600,
+ poll_interval: float = 10,
+ raise_timeout_error: bool = False,
+ exclude_keys: Optional[Set[str]] = None,
+ flow_run_id: Optional[UUID] = None,
+ with_metadata: bool = False,
+ ):
+ self.run_input_cls: Type[AutomaticRunInput[T]] = run_input_cls
+ self.key_prefix: str = key_prefix
+ self.timeout: float | None = timeout
+ self.poll_interval: float = poll_interval
+ self.exclude_keys: set[str] = set()
+ self.raise_timeout_error: bool = raise_timeout_error
+ self.flow_run_id: UUID = ensure_flow_run_id(flow_run_id)
+ self.with_metadata = with_metadata
- def __next__(self) -> T:
- return cast(T, super().__next__())
+ if exclude_keys is not None:
+ self.exclude_keys.update(exclude_keys)
- async def __anext__(self) -> T:
- return cast(T, await super().__anext__())
+ def __iter__(self) -> Self:
+ return self
+
+ def __next__(self) -> T | AutomaticRunInput[T]:
+ try:
+ not_coro = self.next()
+ if TYPE_CHECKING:
+ assert not isinstance(not_coro, Coroutine)
+ return not_coro
+ except TimeoutError:
+ if self.raise_timeout_error:
+ raise
+ raise StopIteration
+
+ def __aiter__(self) -> Self:
+ return self
+
+ async def __anext__(self) -> Union[T, AutomaticRunInput[T]]:
+ try:
+ coro = self.next()
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(coro)
+ return cast(Union[T, AutomaticRunInput[T]], await coro)
+ except TimeoutError:
+ if self.raise_timeout_error:
+ raise
+ raise StopAsyncIteration
+
+ async def filter_for_inputs(self) -> list["FlowRunInput"]:
+ flow_run_inputs_coro = filter_flow_run_input(
+ key_prefix=self.key_prefix,
+ limit=1,
+ exclude_keys=self.exclude_keys,
+ flow_run_id=self.flow_run_id,
+ )
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(flow_run_inputs_coro)
+
+ flow_run_inputs = await flow_run_inputs_coro
+
+ if flow_run_inputs:
+ self.exclude_keys.add(*[i.key for i in flow_run_inputs])
+
+ return flow_run_inputs
@sync_compatible
- async def next(self) -> T:
- return cast(T, await super().next())
+ async def next(self) -> Union[T, AutomaticRunInput[T]]:
+ flow_run_inputs = await self.filter_for_inputs()
+ if flow_run_inputs:
+ return self.to_instance(flow_run_inputs[0])
- def to_instance(self, flow_run_input: "FlowRunInput") -> T:
+ with anyio.fail_after(self.timeout):
+ while True:
+ await anyio.sleep(self.poll_interval)
+ flow_run_inputs = await self.filter_for_inputs()
+ if flow_run_inputs:
+ return self.to_instance(flow_run_inputs[0])
+
+ def to_instance(
+ self, flow_run_input: "FlowRunInput"
+ ) -> Union[T, AutomaticRunInput[T]]:
run_input = self.run_input_cls.load_from_flow_run_input(flow_run_input)
if self.with_metadata:
@@ -503,14 +612,15 @@ def to_instance(self, flow_run_input: "FlowRunInput") -> T:
async def _send_input(
flow_run_id: UUID,
- run_input: Any,
+ run_input: RunInput | pydantic.BaseModel,
sender: Optional[str] = None,
key_prefix: Optional[str] = None,
):
+ _run_input: Union[RunInput, AutomaticRunInput[Any]]
if isinstance(run_input, RunInput):
- _run_input: RunInput = run_input
+ _run_input = run_input
else:
- input_cls: Type[AutomaticRunInput] = run_input_subclass_from_type(
+ input_cls: Type[AutomaticRunInput[Any]] = run_input_subclass_from_type(
type(run_input)
)
_run_input = input_cls(value=run_input)
@@ -520,9 +630,13 @@ async def _send_input(
key = f"{key_prefix}-{uuid4()}"
- await create_flow_run_input_from_model(
+ coro = create_flow_run_input_from_model(
key=key, flow_run_id=flow_run_id, model_instance=_run_input, sender=sender
)
+ if TYPE_CHECKING:
+ assert inspect.iscoroutine(coro)
+
+ await coro
@sync_compatible
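
Splitting `RunInput` into a `BaseRunInput` lets `AutomaticRunInput` share the persistence machinery without inheriting `RunInput.receive`, and the `Self` return types keep subclass identity intact through `load`. The `Self` piece in isolation (illustrative model, not from this diff):

```python
# Minimal sketch of the `Self` classmethod pattern: subclasses get their
# own type back from the loader without per-subclass overrides.
from typing import Any

import pydantic
from typing_extensions import Self


class BaseRunInput(pydantic.BaseModel):
    @classmethod
    def load_from_data(cls, data: dict[str, Any]) -> Self:
        return cls(**data)


class ShirtOrder(BaseRunInput):
    size: str
    color: str


order = ShirtOrder.load_from_data({"size": "small", "color": "red"})
# pyright infers `order: ShirtOrder`, not `BaseRunInput`.
print(order.size)
```
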
diff --git a/src/prefect/logging/configuration.py b/src/prefect/logging/configuration.py
index 9b666668e33d..73216645cd99 100644
--- a/src/prefect/logging/configuration.py
+++ b/src/prefect/logging/configuration.py
@@ -6,7 +6,7 @@
import warnings
from functools import partial
from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional
import yaml
@@ -24,10 +24,10 @@
PROCESS_LOGGING_CONFIG: Optional[Dict[str, Any]] = None
# Regex call to replace non-alphanumeric characters to '_' to create a valid env var
-to_envvar = partial(re.sub, re.compile(r"[^0-9a-zA-Z]+"), "_")
+to_envvar: Callable[[str], str] = partial(re.sub, re.compile(r"[^0-9a-zA-Z]+"), "_")
-def load_logging_config(path: Path) -> dict:
+def load_logging_config(path: Path) -> dict[str, Any]:
"""
Loads logging configuration from a path allowing override from the environment
"""
diff --git a/src/prefect/logging/filters.py b/src/prefect/logging/filters.py
index 43deb4847a6a..013025063d2a 100644
--- a/src/prefect/logging/filters.py
+++ b/src/prefect/logging/filters.py
@@ -5,7 +5,7 @@
from prefect.utilities.names import obfuscate
-def redact_substr(obj: Any, substr: str):
+def redact_substr(obj: Any, substr: str) -> Any:
"""
Redact a string from a potentially nested object.
@@ -17,7 +17,7 @@ def redact_substr(obj: Any, substr: str):
Any: The object with the API key redacted.
"""
- def redact_item(item):
+ def redact_item(item: Any) -> Any:
if isinstance(item, str):
return item.replace(substr, obfuscate(substr))
return item
diff --git a/src/prefect/logging/formatters.py b/src/prefect/logging/formatters.py
index a8eb92c8170d..9fe2c752984b 100644
--- a/src/prefect/logging/formatters.py
+++ b/src/prefect/logging/formatters.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import logging.handlers
import sys
import traceback
from types import TracebackType
-from typing import Optional, Tuple, Type, Union
+from typing import Any, Literal, Optional, Tuple, Type, Union
import orjson
@@ -14,7 +16,7 @@
]
-def format_exception_info(exc_info: ExceptionInfoType) -> dict:
+def format_exception_info(exc_info: ExceptionInfoType) -> dict[str, Any]:
# if sys.exc_info() returned a (None, None, None) tuple,
# then there's nothing to format
if exc_info[0] is None:
@@ -40,13 +42,15 @@ class JsonFormatter(logging.Formatter):
newlines.
"""
- def __init__(self, fmt, dmft, style) -> None: # noqa
+ def __init__(
+ self, fmt: Literal["pretty", "default"], dmft: str, style: str
+ ) -> None: # noqa
super().__init__()
if fmt not in ["pretty", "default"]:
raise ValueError("Format must be either 'pretty' or 'default'.")
- self.serializer = JSONSerializer(
+ self.serializer: JSONSerializer = JSONSerializer(
jsonlib="orjson",
dumps_kwargs={"option": orjson.OPT_INDENT_2} if fmt == "pretty" else {},
)
@@ -72,13 +76,13 @@ def format(self, record: logging.LogRecord) -> str:
class PrefectFormatter(logging.Formatter):
def __init__(
self,
- format=None,
- datefmt=None,
- style="%",
- validate=True,
+ format: str | None = None,
+ datefmt: str | None = None,
+ style: str = "%",
+ validate: bool = True,
*,
- defaults=None,
- task_run_fmt: Optional[str] = None,
+ defaults: dict[str, Any] | None = None,
+ task_run_fmt: str | None = None,
flow_run_fmt: Optional[str] = None,
) -> None:
"""
@@ -118,7 +122,7 @@ def __init__(
self._flow_run_style.validate()
self._task_run_style.validate()
- def formatMessage(self, record: logging.LogRecord):
+ def formatMessage(self, record: logging.LogRecord) -> str:
if record.name == "prefect.flow_runs":
style = self._flow_run_style
elif record.name == "prefect.task_runs":
diff --git a/src/prefect/logging/handlers.py b/src/prefect/logging/handlers.py
index a8323cff6f5b..4b76bef55688 100644
--- a/src/prefect/logging/handlers.py
+++ b/src/prefect/logging/handlers.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import logging
import sys
@@ -6,7 +8,7 @@
import uuid
import warnings
from contextlib import asynccontextmanager
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, TextIO, Type
import pendulum
from rich.console import Console
@@ -34,6 +36,14 @@
PREFECT_LOGGING_TO_API_WHEN_MISSING_FLOW,
)
+if sys.version_info >= (3, 12):
+ StreamHandler = logging.StreamHandler[TextIO]
+else:
+ if TYPE_CHECKING:
+ StreamHandler = logging.StreamHandler[TextIO]
+ else:
+ StreamHandler = logging.StreamHandler
+
class APILogWorker(BatchedQueueService[Dict[str, Any]]):
@property
@@ -90,7 +100,7 @@ class APILogHandler(logging.Handler):
"""
@classmethod
- def flush(cls):
+ def flush(cls) -> None:
"""
Tell the `APILogWorker` to send any currently enqueued logs and block until
completion.
@@ -118,7 +128,7 @@ def flush(cls):
return APILogWorker.drain_all(timeout=5)
@classmethod
- async def aflush(cls):
+ async def aflush(cls) -> bool:
"""
Tell the `APILogWorker` to send any currently enqueued logs and block until
completion.
@@ -126,7 +136,7 @@ async def aflush(cls):
return await APILogWorker.drain_all()
- def emit(self, record: logging.LogRecord):
+ def emit(self, record: logging.LogRecord) -> None:
"""
Send a log to the `APILogWorker`
"""
@@ -239,7 +249,7 @@ def _get_payload_size(self, log: Dict[str, Any]) -> int:
class WorkerAPILogHandler(APILogHandler):
- def emit(self, record: logging.LogRecord):
+ def emit(self, record: logging.LogRecord) -> None:
# Open-source API servers do not currently support worker logs, and
# worker logs only have an associated worker ID when connected to Cloud,
# so we won't send worker logs to the API unless they have a worker ID.
@@ -278,13 +288,13 @@ def prepare(self, record: logging.LogRecord) -> Dict[str, Any]:
return log
-class PrefectConsoleHandler(logging.StreamHandler):
+class PrefectConsoleHandler(StreamHandler):
def __init__(
self,
- stream=None,
- highlighter: Highlighter = PrefectConsoleHighlighter,
- styles: Optional[Dict[str, str]] = None,
- level: Union[int, str] = logging.NOTSET,
+ stream: TextIO | None = None,
+ highlighter: type[Highlighter] = PrefectConsoleHighlighter,
+ styles: dict[str, str] | None = None,
+ level: int | str = logging.NOTSET,
):
"""
The default console handler for Prefect, which highlights log levels,
@@ -307,14 +317,14 @@ def __init__(
theme = Theme(inherit=False)
self.level = level
- self.console = Console(
+ self.console: Console = Console(
highlighter=highlighter,
theme=theme,
file=self.stream,
markup=markup_console,
)
- def emit(self, record: logging.LogRecord):
+ def emit(self, record: logging.LogRecord) -> None:
try:
message = self.format(record)
self.console.print(message, soft_wrap=True)
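
`logging.StreamHandler` cannot be subscripted at runtime on the older interpreters Prefect supports, so the alias is subscripted only on 3.12+ or under `TYPE_CHECKING`; the class statement then sees one consistent name. The pattern in isolation:

```python
# The version-gated generic alias from handlers.py: subscripting
# logging.StreamHandler at runtime raises TypeError on older Pythons,
# so the subscripted form is confined to 3.12+ or to type checking.
import logging
import sys
from typing import TYPE_CHECKING, TextIO

if sys.version_info >= (3, 12):
    StreamHandler = logging.StreamHandler[TextIO]
else:
    if TYPE_CHECKING:
        StreamHandler = logging.StreamHandler[TextIO]
    else:
        StreamHandler = logging.StreamHandler


class ConsoleHandler(StreamHandler):
    def emit(self, record: logging.LogRecord) -> None:
        super().emit(record)


handler = ConsoleHandler(stream=sys.stderr)
```
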
diff --git a/src/prefect/logging/highlighters.py b/src/prefect/logging/highlighters.py
index b842f7c95240..ac9b84a273d0 100644
--- a/src/prefect/logging/highlighters.py
+++ b/src/prefect/logging/highlighters.py
@@ -7,7 +7,7 @@ class LevelHighlighter(RegexHighlighter):
"""Apply style to log levels."""
base_style = "level."
- highlights = [
+ highlights: list[str] = [
r"(?PDEBUG)",
r"(?PINFO)",
r"(?PWARNING)",
@@ -20,7 +20,7 @@ class UrlHighlighter(RegexHighlighter):
"""Apply style to urls."""
base_style = "url."
- highlights = [
+ highlights: list[str] = [
r"(?P(https|http|ws|wss):\/\/[0-9a-zA-Z\$\-\_\+\!`\(\)\,\.\?\/\;\:\&\=\%\#]*)",
r"(?P(file):\/\/[0-9a-zA-Z\$\-\_\+\!`\(\)\,\.\?\/\;\:\&\=\%\#]*)",
]
@@ -30,7 +30,7 @@ class NameHighlighter(RegexHighlighter):
"""Apply style to names."""
base_style = "name."
- highlights = [
+ highlights: list[str] = [
# ?i means case insensitive
# ?<= means find string right after the words: flow run
r"(?i)(?P(?<=flow run) \'(.*?)\')",
@@ -44,7 +44,7 @@ class StateHighlighter(RegexHighlighter):
"""Apply style to states."""
base_style = "state."
- highlights = [
+ highlights: list[str] = [
rf"(?P<{state.lower()}_state>{state.title()})" for state in StateType
] + [
r"(?PCached)(?=\(type=COMPLETED\))" # Highlight only "Cached"
@@ -55,7 +55,7 @@ class PrefectConsoleHighlighter(RegexHighlighter):
"""Applies style from multiple highlighters."""
base_style = "log."
- highlights = (
+ highlights: list[str] = (
LevelHighlighter.highlights
+ UrlHighlighter.highlights
+ NameHighlighter.highlights
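
The named groups in these patterns are load-bearing: rich joins `base_style` with the group name to form the theme key that styles each match. A small self-contained example:

```python
# How base_style + group name become a theme lookup: "level." + "info_level"
# resolves to the "level.info_level" style when INFO matches.
from rich.console import Console
from rich.highlighter import RegexHighlighter
from rich.theme import Theme


class LevelHighlighter(RegexHighlighter):
    base_style = "level."
    highlights: list[str] = [r"(?P<debug_level>DEBUG)", r"(?P<info_level>INFO)"]


console = Console(
    highlighter=LevelHighlighter(),
    theme=Theme({"level.debug_level": "dim", "level.info_level": "bold cyan"}),
)
console.print("INFO | flow run started")
```
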
diff --git a/src/prefect/logging/loggers.py b/src/prefect/logging/loggers.py
index 45a2f6195f73..3021ccec4e32 100644
--- a/src/prefect/logging/loggers.py
+++ b/src/prefect/logging/loggers.py
@@ -7,7 +7,7 @@
from contextlib import contextmanager
from functools import lru_cache
from logging import LogRecord
-from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Mapping, MutableMapping, Optional, Union
from typing_extensions import Self
@@ -15,6 +15,14 @@
from prefect.logging.filters import ObfuscateApiKeyFilter
from prefect.telemetry.logging import add_telemetry_log_handler
+if sys.version_info >= (3, 12):
+ LoggingAdapter = logging.LoggerAdapter[logging.Logger]
+else:
+ if TYPE_CHECKING:
+ LoggingAdapter = logging.LoggerAdapter[logging.Logger]
+ else:
+ LoggingAdapter = logging.LoggerAdapter
+
if TYPE_CHECKING:
from prefect.client.schemas import FlowRun as ClientFlowRun
from prefect.client.schemas.objects import FlowRun, TaskRun
@@ -23,11 +31,6 @@
from prefect.tasks import Task
from prefect.workers.base import BaseWorker
-if sys.version_info >= (3, 12):
- LoggingAdapter = logging.LoggerAdapter[logging.Logger]
-else:
- LoggingAdapter = logging.LoggerAdapter
-
class PrefectLogAdapter(LoggingAdapter):
"""
@@ -39,9 +42,9 @@ class PrefectLogAdapter(LoggingAdapter):
not a bug in the LoggingAdapter and subclassing is the intended workaround.
"""
- extra: Mapping[str, object] | None
-
- def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]: # type: ignore[incompatibleMethodOverride]
+ def process(
+ self, msg: str, kwargs: MutableMapping[str, Any]
+ ) -> tuple[str, MutableMapping[str, Any]]:
kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})}
return (msg, kwargs)
@@ -164,7 +167,7 @@ def flow_run_logger(
flow_run: Union["FlowRun", "ClientFlowRun"],
flow: Optional["Flow[Any, Any]"] = None,
**kwargs: str,
-) -> LoggingAdapter:
+) -> PrefectLogAdapter:
"""
Create a flow run logger with the run's metadata attached.
@@ -192,7 +195,7 @@ def task_run_logger(
flow_run: Optional["FlowRun"] = None,
flow: Optional["Flow[Any, Any]"] = None,
**kwargs: Any,
-):
+) -> LoggingAdapter:
"""
Create a task run logger with the run's metadata attached.
@@ -228,7 +231,9 @@ def task_run_logger(
)
-def get_worker_logger(worker: "BaseWorker", name: Optional[str] = None):
+def get_worker_logger(
+ worker: "BaseWorker[Any, Any, Any]", name: Optional[str] = None
+) -> logging.Logger | LoggingAdapter:
"""
Create a worker logger with the worker's metadata attached.
@@ -364,7 +369,9 @@ def __init__(self, eavesdrop_on: str, level: int = logging.NOTSET):
# It's important that we use a very minimalistic formatter for use cases where
# we may present these logs back to the user. We shouldn't leak filenames,
# versions, or other environmental information.
- self.formatter = logging.Formatter("[%(levelname)s]: %(message)s")
+ self.formatter: logging.Formatter | None = logging.Formatter(
+ "[%(levelname)s]: %(message)s"
+ )
def __enter__(self) -> Self:
self._target_logger = logging.getLogger(self.eavesdrop_on)
@@ -374,7 +381,7 @@ def __enter__(self) -> Self:
self._lines = []
return self
- def __exit__(self, *_):
+ def __exit__(self, *_: Any) -> None:
if self._target_logger:
self._target_logger.removeHandler(self)
self._target_logger.level = self._original_level
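
The `process` override now matches the `MutableMapping` signature of `logging.LoggerAdapter.process`, and the merge keeps both the adapter's bound extras and any per-call extras. Reduced to its essentials:

```python
# The extras-merging process() from loggers.py, reduced: per-call extras
# win on key collisions, but the adapter's bound extras always ride along.
import logging
from typing import Any, MutableMapping


class PrefectLogAdapter(logging.LoggerAdapter):
    def process(
        self, msg: str, kwargs: MutableMapping[str, Any]
    ) -> tuple[str, MutableMapping[str, Any]]:
        kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})}
        return (msg, kwargs)


logging.basicConfig(level=logging.INFO)
logger = PrefectLogAdapter(logging.getLogger("demo"), extra={"flow_run_id": "abc"})
logger.info("started", extra={"attempt": 1})  # record carries both extras
```
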
diff --git a/src/prefect/main.py b/src/prefect/main.py
index d61e2c80e4d7..1e9e5421998f 100644
--- a/src/prefect/main.py
+++ b/src/prefect/main.py
@@ -3,7 +3,7 @@
from prefect.deployments import deploy
from prefect.states import State
from prefect.logging import get_run_logger
-from prefect.flows import flow, Flow, serve, aserve
+from prefect.flows import FlowDecorator, flow, Flow, serve, aserve
from prefect.transactions import Transaction
from prefect.tasks import task, Task
from prefect.context import tags
@@ -58,6 +58,8 @@
inject_renamed_module_alias_finder()
+flow: FlowDecorator
+
# Declare API for type-checkers
__all__ = [
diff --git a/src/prefect/runner/runner.py b/src/prefect/runner/runner.py
index 27fbde6e1a36..d7f7d55eda7c 100644
--- a/src/prefect/runner/runner.py
+++ b/src/prefect/runner/runner.py
@@ -32,6 +32,8 @@ def fast_flow():
"""
+from __future__ import annotations
+
import asyncio
import datetime
import logging
@@ -63,13 +65,14 @@ def fast_flow():
import anyio.abc
import pendulum
from cachetools import LRUCache
+from typing_extensions import Self
from prefect._internal.concurrency.api import (
create_call,
from_async,
from_sync,
)
-from prefect.client.orchestration import get_client
+from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas.filters import (
FlowRunFilter,
FlowRunFilterId,
@@ -196,25 +199,25 @@ def goodbye_flow(name):
if name and ("/" in name or "%" in name):
raise ValueError("Runner name cannot contain '/' or '%'")
- self.name = Path(name).stem if name is not None else f"runner-{uuid4()}"
- self._logger = get_logger("runner")
-
- self.started = False
- self.stopping = False
- self.pause_on_shutdown = pause_on_shutdown
- self.limit = limit or settings.runner.process_limit
- self.webserver = webserver
-
- self.query_seconds = query_seconds or settings.runner.poll_frequency
- self._prefetch_seconds = prefetch_seconds
- self.heartbeat_seconds = (
+ self.name: str = Path(name).stem if name is not None else f"runner-{uuid4()}"
+ self._logger: "logging.Logger" = get_logger("runner")
+
+ self.started: bool = False
+ self.stopping: bool = False
+ self.pause_on_shutdown: bool = pause_on_shutdown
+ self.limit: int | None = limit or settings.runner.process_limit
+ self.webserver: bool = webserver
+
+ self.query_seconds: float = query_seconds or settings.runner.poll_frequency
+ self._prefetch_seconds: float = prefetch_seconds
+ self.heartbeat_seconds: float | None = (
heartbeat_seconds or settings.runner.heartbeat_frequency
)
if self.heartbeat_seconds is not None and self.heartbeat_seconds < 30:
raise ValueError("Heartbeat must be 30 seconds or greater.")
- self._limiter: Optional[anyio.CapacityLimiter] = None
- self._client = get_client()
+ self._limiter: anyio.CapacityLimiter | None = None
+ self._client: PrefectClient = get_client()
self._submitting_flow_run_ids: set[UUID] = set()
self._cancelling_flow_run_ids: set[UUID] = set()
self._scheduled_task_scopes: set[UUID] = set()
@@ -224,8 +227,8 @@ def goodbye_flow(name):
self._tmp_dir: Path = (
Path(tempfile.gettempdir()) / "runner_storage" / str(uuid4())
)
- self._storage_objs: List[RunnerStorage] = []
- self._deployment_storage_map: Dict[UUID, RunnerStorage] = {}
+ self._storage_objs: list[RunnerStorage] = []
+ self._deployment_storage_map: dict[UUID, RunnerStorage] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
@@ -488,7 +491,7 @@ def execute_in_background(
return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop)
- async def cancel_all(self):
+ async def cancel_all(self) -> None:
runs_to_cancel = []
# done to avoid dictionary size changing during iteration
@@ -525,7 +528,7 @@ async def stop(self):
async def execute_flow_run(
self, flow_run_id: UUID, entrypoint: Optional[str] = None
- ):
+ ) -> None:
"""
Executes a single flow run with the given ID.
@@ -777,7 +780,7 @@ async def _get_and_submit_flow_runs(self):
if self.stopping:
return
runs_response = await self._get_scheduled_flow_runs()
- self.last_polled = pendulum.now("UTC")
+ self.last_polled: pendulum.DateTime = pendulum.now("UTC")
return await self._submit_scheduled_flow_runs(flow_run_response=runs_response)
async def _check_for_cancelled_flow_runs(
@@ -1390,7 +1393,7 @@ async def _run_on_crashed_hooks(
await _run_hooks(hooks, flow_run, flow, state)
- async def __aenter__(self):
+ async def __aenter__(self) -> Self:
self._logger.debug("Starting runner...")
self._client = get_client()
self._tmp_dir.mkdir(parents=True)
@@ -1412,7 +1415,7 @@ async def __aenter__(self):
self.started = True
return self
- async def __aexit__(self, *exc_info: Any):
+ async def __aexit__(self, *exc_info: Any) -> None:
self._logger.debug("Stopping runner...")
if self.pause_on_shutdown:
await self._pause_schedules()
@@ -1430,7 +1433,7 @@ async def __aexit__(self, *exc_info: Any):
shutil.rmtree(str(self._tmp_dir))
del self._runs_task_group, self._loops_task_group
- def __repr__(self):
+ def __repr__(self) -> str:
return f"Runner(name={self.name!r})"
diff --git a/src/prefect/runner/server.py b/src/prefect/runner/server.py
index 9a3688b09b5c..b58cbf02e4d5 100644
--- a/src/prefect/runner/server.py
+++ b/src/prefect/runner/server.py
@@ -26,12 +26,14 @@
from prefect.utilities.importtools import load_script_as_module
if TYPE_CHECKING:
+ import logging
+
from prefect.client.schemas.responses import DeploymentResponse
from prefect.runner import Runner
from pydantic import BaseModel
-logger = get_logger("webserver")
+logger: "logging.Logger" = get_logger("webserver")
RunnableEndpoint = Literal["deployment", "flow", "task"]
diff --git a/src/prefect/runner/storage.py b/src/prefect/runner/storage.py
index 7a374b65eb8d..031ab1a05139 100644
--- a/src/prefect/runner/storage.py
+++ b/src/prefect/runner/storage.py
@@ -33,7 +33,7 @@ class RunnerStorage(Protocol):
remotely stored flow code.
"""
- def set_base_path(self, path: Path):
+ def set_base_path(self, path: Path) -> None:
"""
Sets the base path to use when pulling contents from remote storage to
local storage.
@@ -55,7 +55,7 @@ def destination(self) -> Path:
"""
...
- async def pull_code(self):
+ async def pull_code(self) -> None:
"""
Pulls contents from remote storage to the local filesystem.
"""
@@ -150,7 +150,7 @@ def __init__(
def destination(self) -> Path:
return self._storage_base_path / self._name
- def set_base_path(self, path: Path):
+ def set_base_path(self, path: Path) -> None:
self._storage_base_path = path
@property
@@ -221,7 +221,7 @@ async def is_sparsely_checked_out(self) -> bool:
except Exception:
return False
- async def pull_code(self):
+ async def pull_code(self) -> None:
"""
Pulls the contents of the configured repository to the local filesystem.
"""
@@ -324,7 +324,7 @@ async def _clone_repo(self):
cwd=self.destination,
)
- def __eq__(self, __value) -> bool:
+ def __eq__(self, __value: Any) -> bool:
if isinstance(__value, GitRepository):
return (
self._url == __value._url
@@ -339,7 +339,7 @@ def __repr__(self) -> str:
f" branch={self._branch!r})"
)
- def to_pull_step(self) -> Dict:
+ def to_pull_step(self) -> dict[str, Any]:
pull_step = {
"prefect.deployments.steps.git_clone": {
"repository": self._url,
@@ -466,7 +466,7 @@ def replace_blocks_with_values(obj: Any) -> Any:
return fsspec.filesystem(scheme, **settings_with_block_values)
- def set_base_path(self, path: Path):
+ def set_base_path(self, path: Path) -> None:
self._storage_base_path = path
@property
@@ -492,7 +492,7 @@ def _remote_path(self) -> Path:
_, netloc, urlpath, _, _ = urlsplit(self._url)
return Path(netloc) / Path(urlpath.lstrip("/"))
- async def pull_code(self):
+ async def pull_code(self) -> None:
"""
Pulls contents from remote storage to the local filesystem.
"""
@@ -522,7 +522,7 @@ async def pull_code(self):
f" {self.destination!r}"
) from exc
- def to_pull_step(self) -> dict:
+ def to_pull_step(self) -> dict[str, Any]:
"""
Returns a dictionary representation of the storage object that can be
used as a deployment pull step.
@@ -551,7 +551,7 @@ def replace_block_with_placeholder(obj: Any) -> Any:
] = required_package
return step
- def __eq__(self, __value) -> bool:
+ def __eq__(self, __value: Any) -> bool:
"""
Equality check for runner storage objects.
"""
@@ -590,7 +590,7 @@ def __init__(
else str(uuid4())
)
- def set_base_path(self, path: Path):
+ def set_base_path(self, path: Path) -> None:
self._storage_base_path = path
@property
@@ -601,12 +601,12 @@ def pull_interval(self) -> Optional[int]:
def destination(self) -> Path:
return self._storage_base_path / self._name
- async def pull_code(self):
+ async def pull_code(self) -> None:
if not self.destination.exists():
self.destination.mkdir(parents=True, exist_ok=True)
await self._block.get_directory(local_path=str(self.destination))
- def to_pull_step(self) -> dict:
+ def to_pull_step(self) -> dict[str, Any]:
# Give blocks the chance to implement their own pull step
if hasattr(self._block, "get_pull_step"):
return self._block.get_pull_step()
@@ -623,7 +623,7 @@ def to_pull_step(self) -> dict:
}
}
- def __eq__(self, __value) -> bool:
+ def __eq__(self, __value: Any) -> bool:
if isinstance(__value, BlockStorageAdapter):
return self._block == __value._block
return False
@@ -658,19 +658,19 @@ def __init__(
def destination(self) -> Path:
return self._path
- def set_base_path(self, path: Path):
+ def set_base_path(self, path: Path) -> None:
self._storage_base_path = path
@property
def pull_interval(self) -> Optional[int]:
return self._pull_interval
- async def pull_code(self):
+ async def pull_code(self) -> None:
# Local storage assumes the code already exists on the local filesystem
# and does not need to be pulled from a remote location
pass
- def to_pull_step(self) -> dict:
+ def to_pull_step(self) -> dict[str, Any]:
"""
Returns a dictionary representation of the storage object that can be
used as a deployment pull step.
@@ -682,7 +682,7 @@ def to_pull_step(self) -> dict:
}
return step
- def __eq__(self, __value) -> bool:
+ def __eq__(self, __value: Any) -> bool:
if isinstance(__value, LocalStorage):
return self._path == __value._path
return False
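
These signature fixes matter twice over for the storage classes: `RunnerStorage` is a `Protocol`, so implementations conform structurally and their annotations must line up with the protocol's. A reduced sketch of the contract, trimmed to three members:

```python
# Reduced sketch of the structural-typing contract: LocalStorage satisfies
# RunnerStorage by matching the (now fully annotated) method signatures.
from pathlib import Path
from typing import Any, Protocol


class RunnerStorage(Protocol):
    def set_base_path(self, path: Path) -> None: ...
    async def pull_code(self) -> None: ...
    def to_pull_step(self) -> dict[str, Any]: ...


class LocalStorage:
    def __init__(self, path: str) -> None:
        self._path = Path(path).resolve()

    def set_base_path(self, path: Path) -> None:
        self._storage_base_path = path

    async def pull_code(self) -> None:
        # Code already exists on the local filesystem; nothing to pull.
        pass

    def to_pull_step(self) -> dict[str, Any]:
        return {
            "prefect.deployments.steps.set_working_directory": {
                "directory": str(self._path)
            }
        }


storage: RunnerStorage = LocalStorage(".")  # structural match, checked statically
```
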
diff --git a/src/prefect/runner/submit.py b/src/prefect/runner/submit.py
index ec42a4029a79..56c14b392e98 100644
--- a/src/prefect/runner/submit.py
+++ b/src/prefect/runner/submit.py
@@ -1,11 +1,13 @@
+from __future__ import annotations
+
import asyncio
import inspect
import uuid
-from typing import Any, Dict, List, Optional, Union, overload
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload
import anyio
import httpx
-from typing_extensions import Literal
+from typing_extensions import Literal, TypeAlias
from prefect.client.orchestration import get_client
from prefect.client.schemas.filters import FlowRunFilter, TaskRunFilter
@@ -22,12 +24,17 @@
from prefect.tasks import Task
from prefect.utilities.asyncutils import sync_compatible
-logger = get_logger("webserver")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("webserver")
+
+FlowOrTask: TypeAlias = Union[Flow[Any, Any], Task[Any, Any]]
async def _submit_flow_to_runner(
- flow: Flow,
- parameters: Dict[str, Any],
+ flow: Flow[Any, Any],
+ parameters: dict[str, Any],
retry_failed_submissions: bool = True,
) -> FlowRun:
"""
@@ -91,7 +98,7 @@ async def _submit_flow_to_runner(
@overload
def submit_to_runner(
- prefect_callable: Union[Flow, Task],
+ prefect_callable: Union[Flow[Any, Any], Task[Any, Any]],
parameters: Dict[str, Any],
retry_failed_submissions: bool = True,
) -> FlowRun:
@@ -100,19 +107,19 @@ def submit_to_runner(
@overload
def submit_to_runner(
- prefect_callable: Union[Flow, Task],
- parameters: List[Dict[str, Any]],
+ prefect_callable: Union[Flow[Any, Any], Task[Any, Any]],
+ parameters: list[dict[str, Any]],
retry_failed_submissions: bool = True,
-) -> List[FlowRun]:
+) -> list[FlowRun]:
...
@sync_compatible
async def submit_to_runner(
- prefect_callable: Union[Flow, Task],
- parameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+ prefect_callable: FlowOrTask,
+ parameters: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
retry_failed_submissions: bool = True,
-) -> Union[FlowRun, List[FlowRun]]:
+) -> Union[FlowRun, list[FlowRun]]:
"""
Submit a callable in the background via the runner webserver one or more times.
diff --git a/src/prefect/runtime/deployment.py b/src/prefect/runtime/deployment.py
index c0acff8c39bb..9dd2c49938d2 100644
--- a/src/prefect/runtime/deployment.py
+++ b/src/prefect/runtime/deployment.py
@@ -25,8 +25,10 @@ def get_task_runner():
object or those directly provided via API for this run
"""
+from __future__ import annotations
+
import os
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, List, Optional
from prefect._internal.concurrency.api import create_call, from_sync
from prefect.client.orchestration import get_client
@@ -34,12 +36,17 @@ def get_task_runner():
from .flow_run import _get_flow_run
+if TYPE_CHECKING:
+ from prefect.client.schemas.responses import DeploymentResponse
+
__all__ = ["id", "flow_run_id", "name", "parameters", "version"]
-CACHED_DEPLOYMENT = {}
+CACHED_DEPLOYMENT: dict[str, "DeploymentResponse"] = {}
-type_cast = {
+type_cast: dict[
+ type[bool] | type[int] | type[float] | type[str] | type[None], Callable[[Any], Any]
+] = {
bool: lambda x: x.lower() == "true",
int: int,
float: float,
@@ -88,7 +95,7 @@ def __dir__() -> List[str]:
return sorted(__all__)
-async def _get_deployment(deployment_id):
+async def _get_deployment(deployment_id: str) -> "DeploymentResponse":
# deployments won't change between calls so let's avoid the lifecycle of a client
if CACHED_DEPLOYMENT.get(deployment_id):
return CACHED_DEPLOYMENT[deployment_id]
@@ -115,7 +122,7 @@ def get_id() -> Optional[str]:
return str(deployment_id)
-def get_parameters() -> Dict:
+def get_parameters() -> dict[str, Any]:
run_id = get_flow_run_id()
if run_id is None:
return {}
@@ -126,7 +133,7 @@ def get_parameters() -> Dict:
return flow_run.parameters or {}
-def get_name() -> Optional[Dict]:
+def get_name() -> Optional[str]:
dep_id = get_id()
if dep_id is None:
@@ -138,7 +145,7 @@ def get_name() -> Optional[Dict]:
return deployment.name
-def get_version() -> Optional[Dict]:
+def get_version() -> Optional[str]:
dep_id = get_id()
if dep_id is None:
@@ -154,7 +161,7 @@ def get_flow_run_id() -> Optional[str]:
return os.getenv("PREFECT__FLOW_RUN_ID")
-FIELDS = {
+FIELDS: dict[str, Callable[[], Any]] = {
"id": get_id,
"flow_run_id": get_flow_run_id,
"parameters": get_parameters,
diff --git a/src/prefect/runtime/flow_run.py b/src/prefect/runtime/flow_run.py
index 3c9be513bba8..0c4709e7aa9c 100644
--- a/src/prefect/runtime/flow_run.py
+++ b/src/prefect/runtime/flow_run.py
@@ -20,8 +20,10 @@
- `run_count`: the number of times this flow run has been run
"""
+from __future__ import annotations
+
import os
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import pendulum
@@ -30,6 +32,9 @@
from prefect.context import FlowRunContext, TaskRunContext
from prefect.settings import PREFECT_API_URL, PREFECT_UI_URL
+if TYPE_CHECKING:
+ from prefect.client.schemas.objects import Flow, FlowRun, TaskRun
+
__all__ = [
"id",
"tags",
@@ -56,7 +61,15 @@ def _pendulum_parse(dt: str) -> pendulum.DateTime:
return pendulum.parse(dt, tz=None, strict=False).set(tz="UTC")
-type_cast = {
+type_cast: dict[
+ type[bool]
+ | type[int]
+ | type[float]
+ | type[str]
+ | type[None]
+ | type[pendulum.DateTime],
+ Callable[[Any], Any],
+] = {
bool: lambda x: x.lower() == "true",
int: int,
float: float,
@@ -106,17 +119,17 @@ def __dir__() -> List[str]:
return sorted(__all__)
-async def _get_flow_run(flow_run_id):
+async def _get_flow_run(flow_run_id: str) -> "FlowRun":
async with get_client() as client:
return await client.read_flow_run(flow_run_id)
-async def _get_task_run(task_run_id):
+async def _get_task_run(task_run_id: str) -> "TaskRun":
async with get_client() as client:
return await client.read_task_run(task_run_id)
-async def _get_flow_from_run(flow_run_id):
+async def _get_flow_from_run(flow_run_id: str) -> "Flow":
async with get_client() as client:
flow_run = await client.read_flow_run(flow_run_id)
return await client.read_flow(flow_run.flow_id)
@@ -323,7 +336,7 @@ def get_job_variables() -> Optional[Dict[str, Any]]:
return flow_run_ctx.flow_run.job_variables if flow_run_ctx else None
-FIELDS = {
+FIELDS: dict[str, Callable[[], Any]] = {
"id": get_id,
"tags": get_tags,
"scheduled_start_time": get_scheduled_start_time,
diff --git a/src/prefect/runtime/task_run.py b/src/prefect/runtime/task_run.py
index da28070fa826..9b3e4cd5b1f9 100644
--- a/src/prefect/runtime/task_run.py
+++ b/src/prefect/runtime/task_run.py
@@ -15,15 +15,19 @@
- `task_name`: the name of the task
"""
+from __future__ import annotations
+
import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
from prefect.context import TaskRunContext
__all__ = ["id", "tags", "name", "parameters", "run_count", "task_name"]
-type_cast = {
+type_cast: dict[
+ type[bool] | type[int] | type[float] | type[str] | type[None], Callable[[Any], Any]
+] = {
bool: lambda x: x.lower() == "true",
int: int,
float: float,
@@ -118,7 +122,7 @@ def get_parameters() -> Dict[str, Any]:
return {}
-FIELDS = {
+FIELDS: dict[str, Callable[[], Any]] = {
"id": get_id,
"tags": get_tags,
"name": get_name,
diff --git a/src/prefect/server/api/admin.py b/src/prefect/server/api/admin.py
index b55eb0bafd49..51c37460be64 100644
--- a/src/prefect/server/api/admin.py
+++ b/src/prefect/server/api/admin.py
@@ -9,7 +9,7 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/admin", tags=["Admin"])
+router: PrefectRouter = PrefectRouter(prefix="/admin", tags=["Admin"])
@router.get("/settings")
@@ -37,7 +37,7 @@ async def clear_database(
description="Pass confirm=True to confirm you want to modify the database.",
),
response: Response = None, # type: ignore
-):
+) -> None:
"""Clear all database tables without dropping them."""
if not confirm:
response.status_code = status.HTTP_400_BAD_REQUEST
@@ -58,7 +58,7 @@ async def drop_database(
description="Pass confirm=True to confirm you want to modify the database.",
),
response: Response = None,
-):
+) -> None:
"""Drop all database objects."""
if not confirm:
response.status_code = status.HTTP_400_BAD_REQUEST
@@ -76,7 +76,7 @@ async def create_database(
description="Pass confirm=True to confirm you want to modify the database.",
),
response: Response = None,
-):
+) -> None:
"""Create all database objects."""
if not confirm:
response.status_code = status.HTTP_400_BAD_REQUEST
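
The same change repeats across the API modules below: module-level singletons like `router` gain declared types. This is presumably for `pyright --verifytypes`, which scores the public interface and treats an explicit declaration as unambiguous where inference may not be. A stand-in using plain FastAPI (`PrefectRouter` subclasses `APIRouter`):

```python
# Hypothetical stand-in: declaring the module-level symbol's type makes
# the public interface explicit for pyright's completeness scoring.
from fastapi import APIRouter

router: APIRouter = APIRouter(prefix="/admin", tags=["Admin"])
```
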
diff --git a/src/prefect/server/api/artifacts.py b/src/prefect/server/api/artifacts.py
index d3300b1a348f..9e6c049c4a48 100644
--- a/src/prefect/server/api/artifacts.py
+++ b/src/prefect/server/api/artifacts.py
@@ -14,7 +14,7 @@
from prefect.server.schemas import actions, core, filters, sorting
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(
+router: PrefectRouter = PrefectRouter(
prefix="/artifacts",
tags=["Artifacts"],
)
@@ -191,7 +191,7 @@ async def update_artifact(
..., description="The ID of the artifact to update.", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Update an artifact in the database.
"""
@@ -211,7 +211,7 @@ async def delete_artifact(
..., description="The ID of the artifact to delete.", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete an artifact from the database.
"""
diff --git a/src/prefect/server/api/automations.py b/src/prefect/server/api/automations.py
index 832e66515f17..847442273177 100644
--- a/src/prefect/server/api/automations.py
+++ b/src/prefect/server/api/automations.py
@@ -27,7 +27,7 @@
ValidationError as JSONSchemaValidationError,
)
-router = PrefectRouter(
+router: PrefectRouter = PrefectRouter(
prefix="/automations",
tags=["Automations"],
dependencies=[],
@@ -91,7 +91,7 @@ async def update_automation(
automation: AutomationUpdate,
automation_id: UUID = Path(..., alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
# reset any client-provided IDs on the provided triggers
automation.trigger.reset_ids()
@@ -134,7 +134,7 @@ async def patch_automation(
automation: AutomationPartialUpdate,
automation_id: UUID = Path(..., alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
try:
async with db.session_context(begin_transaction=True) as session:
updated = await automations_models.update_automation(
@@ -159,7 +159,7 @@ async def patch_automation(
async def delete_automation(
automation_id: UUID = Path(..., alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
deleted = await automations_models.delete_automation(
session=session,
@@ -228,7 +228,7 @@ async def read_automations_related_to_resource(
async def delete_automations_owned_by_resource(
resource_id: str = Path(..., alias="resource_id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
await automations_models.delete_automations_owned_by_resource(
session,
diff --git a/src/prefect/server/api/block_capabilities.py b/src/prefect/server/api/block_capabilities.py
index 3da721adaf30..f386391f12ec 100644
--- a/src/prefect/server/api/block_capabilities.py
+++ b/src/prefect/server/api/block_capabilities.py
@@ -10,7 +10,9 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/block_capabilities", tags=["Block capabilities"])
+router: PrefectRouter = PrefectRouter(
+ prefix="/block_capabilities", tags=["Block capabilities"]
+)
@router.get("/")
diff --git a/src/prefect/server/api/block_documents.py b/src/prefect/server/api/block_documents.py
index dd8deb82b6dc..02153f6cf0e5 100644
--- a/src/prefect/server/api/block_documents.py
+++ b/src/prefect/server/api/block_documents.py
@@ -12,7 +12,9 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/block_documents", tags=["Block documents"])
+router: PrefectRouter = PrefectRouter(
+ prefix="/block_documents", tags=["Block documents"]
+)
@router.post("/", status_code=status.HTTP_201_CREATED)
@@ -124,7 +126,7 @@ async def delete_block_document(
..., description="The block document id", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
result = await models.block_documents.delete_block_document(
session=session, block_document_id=block_document_id
@@ -142,7 +144,7 @@ async def update_block_document_data(
..., description="The block document id", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
try:
async with db.session_context(begin_transaction=True) as session:
result = await models.block_documents.update_block_document(
diff --git a/src/prefect/server/api/block_schemas.py b/src/prefect/server/api/block_schemas.py
index d079f24187ce..36525a0e9531 100644
--- a/src/prefect/server/api/block_schemas.py
+++ b/src/prefect/server/api/block_schemas.py
@@ -21,7 +21,7 @@
from prefect.server.models.block_schemas import MissingBlockTypeException
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/block_schemas", tags=["Block schemas"])
+router: PrefectRouter = PrefectRouter(prefix="/block_schemas", tags=["Block schemas"])
@router.post("/", status_code=status.HTTP_201_CREATED)
@@ -68,8 +68,8 @@ async def create_block_schema(
async def delete_block_schema(
block_schema_id: UUID = Path(..., description="The block schema id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
- api_version=Depends(dependencies.provide_request_api_version),
-):
+ api_version: str = Depends(dependencies.provide_request_api_version),
+) -> None:
"""
Delete a block schema by id.
"""
diff --git a/src/prefect/server/api/block_types.py b/src/prefect/server/api/block_types.py
index 50351781dc64..b29a8eeb0144 100644
--- a/src/prefect/server/api/block_types.py
+++ b/src/prefect/server/api/block_types.py
@@ -10,7 +10,7 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/block_types", tags=["Block types"])
+router: PrefectRouter = PrefectRouter(prefix="/block_types", tags=["Block types"])
@router.post("/", status_code=status.HTTP_201_CREATED)
@@ -101,7 +101,7 @@ async def update_block_type(
block_type: schemas.actions.BlockTypeUpdate,
block_type_id: UUID = Path(..., description="The block type ID", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Update a block type.
"""
@@ -131,7 +131,7 @@ async def update_block_type(
async def delete_block_type(
block_type_id: UUID = Path(..., description="The block type ID", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
db_block_type = await models.block_types.read_block_type(
session=session, block_type_id=block_type_id
@@ -204,7 +204,7 @@ async def read_block_document_by_name_for_block_type(
@router.post("/install_system_block_types")
async def install_system_block_types(
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
# Don't begin a transaction. _install_protected_system_blocks will manage
# the transactions.
async with db.session_context(begin_transaction=False) as session:
diff --git a/src/prefect/server/api/clients.py b/src/prefect/server/api/clients.py
index 1c4d5424e5cc..3a2dd6a1d763 100644
--- a/src/prefect/server/api/clients.py
+++ b/src/prefect/server/api/clients.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote
from uuid import UUID
@@ -19,7 +19,10 @@
from prefect.server.schemas.responses import DeploymentResponse, OrchestrationResult
from prefect.types import StrictVariableValue
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
class BaseClient:
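
The `if TYPE_CHECKING: import logging` block above recurs in most modules this patch touches: the quoted `"logging.Logger"` annotation satisfies the type checker without paying for a runtime import that only the annotation needs. A standalone sketch (the `get_logger` body below is a hypothetical stand-in for `prefect.logging.get_logger`):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only for the annotation below; never executed at runtime
    import logging


def get_logger(name: str | None = None) -> "logging.Logger":
    import logging as _logging  # hypothetical stand-in implementation

    return _logging.getLogger(name)


logger: "logging.Logger" = get_logger(__name__)
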
diff --git a/src/prefect/server/api/collections.py b/src/prefect/server/api/collections.py
index 66b47d9d2068..a5c01ebc3f82 100644
--- a/src/prefect/server/api/collections.py
+++ b/src/prefect/server/api/collections.py
@@ -8,9 +8,11 @@
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/collections", tags=["Collections"])
+router: PrefectRouter = PrefectRouter(prefix="/collections", tags=["Collections"])
-GLOBAL_COLLECTIONS_VIEW_CACHE: TTLCache = TTLCache(maxsize=200, ttl=60 * 10)
+GLOBAL_COLLECTIONS_VIEW_CACHE: TTLCache[str, dict[str, Any]] = TTLCache(
+ maxsize=200, ttl=60 * 10
+)
REGISTRY_VIEWS = (
"https://raw.githubusercontent.com/PrefectHQ/prefect-collection-registry/main/views"
@@ -43,7 +45,7 @@ async def read_view_content(view: str) -> Dict[str, Any]:
raise
-async def get_collection_view(view: str):
+async def get_collection_view(view: str) -> dict[str, Any]:
try:
return GLOBAL_COLLECTIONS_VIEW_CACHE[view]
except KeyError:
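
The cache annotation above fixes an implicit generic: a bare `TTLCache` leaves its key and value types unknown, so reads and writes go unchecked. A short sketch under the same `maxsize`/`ttl` settings (the cached key name here is illustrative):

from __future__ import annotations

from typing import Any

from cachetools import TTLCache

# keys are view names, values are the decoded JSON documents
view_cache: TTLCache[str, dict[str, Any]] = TTLCache(maxsize=200, ttl=60 * 10)

view_cache["demo-view"] = {"entries": []}
cached: dict[str, Any] = view_cache["demo-view"]  # fully checked read
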
diff --git a/src/prefect/server/api/concurrency_limits.py b/src/prefect/server/api/concurrency_limits.py
index 990cac57feeb..231f88b19e6d 100644
--- a/src/prefect/server/api/concurrency_limits.py
+++ b/src/prefect/server/api/concurrency_limits.py
@@ -17,7 +17,9 @@
from prefect.server.utilities.server import PrefectRouter
from prefect.settings import PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS
-router = PrefectRouter(prefix="/concurrency_limits", tags=["Concurrency Limits"])
+router: PrefectRouter = PrefectRouter(
+ prefix="/concurrency_limits", tags=["Concurrency Limits"]
+)
@router.post("/")
@@ -119,7 +121,7 @@ async def reset_concurrency_limit_by_tag(
description="Manual override for active concurrency limit slots.",
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
model = await models.concurrency_limits.reset_concurrency_limit_by_tag(
session=session, tag=tag, slot_override=slot_override
@@ -136,7 +138,7 @@ async def delete_concurrency_limit(
..., description="The concurrency limit id", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
result = await models.concurrency_limits.delete_concurrency_limit(
session=session, concurrency_limit_id=concurrency_limit_id
@@ -151,7 +153,7 @@ async def delete_concurrency_limit(
async def delete_concurrency_limit_by_tag(
tag: str = Path(..., description="The tag name"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
result = await models.concurrency_limits.delete_concurrency_limit_by_tag(
session=session, tag=tag
@@ -263,7 +265,7 @@ async def decrement_concurrency_limits_v1(
..., description="The ID of the task run releasing the slot"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
filtered_limits = (
await concurrency_limits.filter_concurrency_limits_for_orchestration(
diff --git a/src/prefect/server/api/concurrency_limits_v2.py b/src/prefect/server/api/concurrency_limits_v2.py
index 76d96313c060..75b0f25115c2 100644
--- a/src/prefect/server/api/concurrency_limits_v2.py
+++ b/src/prefect/server/api/concurrency_limits_v2.py
@@ -11,7 +11,9 @@
from prefect.server.utilities.schemas import PrefectBaseModel
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/v2/concurrency_limits", tags=["Concurrency Limits V2"])
+router: PrefectRouter = PrefectRouter(
+ prefix="/v2/concurrency_limits", tags=["Concurrency Limits V2"]
+)
@router.post("/", status_code=status.HTTP_201_CREATED)
@@ -85,7 +87,7 @@ async def update_concurrency_limit_v2(
..., description="The ID or name of the concurrency limit", alias="id_or_name"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
if isinstance(id_or_name, str): # TODO: this seems like it shouldn't be necessary
try:
id_or_name = UUID(id_or_name)
@@ -115,7 +117,7 @@ async def delete_concurrency_limit_v2(
..., description="The ID or name of the concurrency limit", alias="id_or_name"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
if isinstance(id_or_name, str): # TODO: this seems like it shouldn't be necessary
try:
id_or_name = UUID(id_or_name)
diff --git a/src/prefect/server/api/csrf_token.py b/src/prefect/server/api/csrf_token.py
index 94eaa399da7f..b4554876b9ae 100644
--- a/src/prefect/server/api/csrf_token.py
+++ b/src/prefect/server/api/csrf_token.py
@@ -1,3 +1,5 @@
+from typing import TYPE_CHECKING
+
from fastapi import Depends, Query, status
from starlette.exceptions import HTTPException
@@ -7,9 +9,12 @@
from prefect.server.utilities.server import PrefectRouter
from prefect.settings import PREFECT_SERVER_CSRF_PROTECTION_ENABLED
-logger = get_logger("server.api")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("server.api")
-router = PrefectRouter(prefix="/csrf-token")
+router: PrefectRouter = PrefectRouter(prefix="/csrf-token")
@router.get("")
diff --git a/src/prefect/server/api/dependencies.py b/src/prefect/server/api/dependencies.py
index 6f9fb797b12e..b20db69cf127 100644
--- a/src/prefect/server/api/dependencies.py
+++ b/src/prefect/server/api/dependencies.py
@@ -1,11 +1,12 @@
"""
Utilities for injecting FastAPI dependencies.
"""
+from __future__ import annotations
import logging
import re
from base64 import b64decode
-from typing import Annotated, Optional
+from typing import Annotated, Any, Optional
from uuid import UUID
from fastapi import Body, Depends, Header, HTTPException, status
@@ -16,13 +17,15 @@
from prefect.settings import PREFECT_API_DEFAULT_LIMIT
-def provide_request_api_version(x_prefect_api_version: str = Header(None)):
+def provide_request_api_version(
+ x_prefect_api_version: str = Header(None),
+) -> Version | None:
if not x_prefect_api_version:
return
# parse version
try:
- major, minor, patch = [int(v) for v in x_prefect_api_version.split(".")]
+ _, _, _ = [int(v) for v in x_prefect_api_version.split(".")]
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -46,15 +49,15 @@ class EnforceMinimumAPIVersion:
def __init__(self, minimum_api_version: str, logger: logging.Logger):
self.minimum_api_version = minimum_api_version
versions = [int(v) for v in minimum_api_version.split(".")]
- self.api_major = versions[0]
- self.api_minor = versions[1]
- self.api_patch = versions[2]
+ self.api_major: int = versions[0]
+ self.api_minor: int = versions[1]
+ self.api_patch: int = versions[2]
self.logger = logger
async def __call__(
self,
x_prefect_api_version: str = Header(None),
- ):
+ ) -> None:
request_version = x_prefect_api_version
# if no version header, assume latest and continue
@@ -96,7 +99,7 @@ async def _notify_of_outdated_version(self, request_version: str):
)
-def LimitBody() -> Depends:
+def LimitBody() -> Any:
"""
A `fastapi.Depends` factory for pulling a `limit: int` parameter from the
request body while determining the default from the current settings.
@@ -163,7 +166,7 @@ def get_updated_by(
return None
-def is_ephemeral_request(request: Request):
+def is_ephemeral_request(request: Request) -> bool:
"""
A dependency that returns whether the request is to an ephemeral server.
"""
diff --git a/src/prefect/server/api/deployments.py b/src/prefect/server/api/deployments.py
index b0b89b721f79..96463c09e79a 100644
--- a/src/prefect/server/api/deployments.py
+++ b/src/prefect/server/api/deployments.py
@@ -37,7 +37,7 @@
validate,
)
-router = PrefectRouter(prefix="/deployments", tags=["Deployments"])
+router: PrefectRouter = PrefectRouter(prefix="/deployments", tags=["Deployments"])
def _multiple_schedules_error(deployment_id) -> HTTPException:
@@ -181,7 +181,7 @@ async def update_deployment(
deployment: schemas.actions.DeploymentUpdate,
deployment_id: UUID = Path(..., description="The deployment id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
existing_deployment = await models.deployments.read_deployment(
session=session, deployment_id=deployment_id
@@ -485,7 +485,7 @@ async def count_deployments(
async def delete_deployment(
deployment_id: UUID = Path(..., description="The deployment id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a deployment by id.
"""
@@ -811,7 +811,7 @@ async def update_deployment_schedule(
default=..., description="The updated schedule"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
deployment = await models.deployments.read_deployment(
session=session, deployment_id=deployment_id
@@ -844,7 +844,7 @@ async def delete_deployment_schedule(
deployment_id: UUID = Path(..., description="The deployment id", alias="id"),
schedule_id: UUID = Path(..., description="The schedule id", alias="schedule_id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
deployment = await models.deployments.read_deployment(
session=session, deployment_id=deployment_id
diff --git a/src/prefect/server/api/events.py b/src/prefect/server/api/events.py
index 87f4f914f484..7e5c87f7e0c3 100644
--- a/src/prefect/server/api/events.py
+++ b/src/prefect/server/api/events.py
@@ -1,5 +1,5 @@
import base64
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
from fastapi import Response, WebSocket, status
from fastapi.exceptions import HTTPException
@@ -34,17 +34,20 @@
PREFECT_EVENTS_WEBSOCKET_BACKFILL_PAGE_SIZE,
)
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+logger: "logging.Logger" = get_logger(__name__)
-router = PrefectRouter(prefix="/events", tags=["Events"])
+
+router: PrefectRouter = PrefectRouter(prefix="/events", tags=["Events"])
@router.post("", status_code=status.HTTP_204_NO_CONTENT, response_class=Response)
async def create_events(
events: List[Event],
ephemeral_request: bool = Depends(is_ephemeral_request),
-):
+) -> None:
"""Record a batch of Events"""
if ephemeral_request:
await EventsPipeline().process_events(events)
diff --git a/src/prefect/server/api/flow_run_notification_policies.py b/src/prefect/server/api/flow_run_notification_policies.py
index c0a6da90274b..47b4f871fc2a 100644
--- a/src/prefect/server/api/flow_run_notification_policies.py
+++ b/src/prefect/server/api/flow_run_notification_policies.py
@@ -13,7 +13,7 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(
+router: PrefectRouter = PrefectRouter(
prefix="/flow_run_notification_policies", tags=["Flow Run Notification Policies"]
)
@@ -39,7 +39,7 @@ async def update_flow_run_notification_policy(
..., description="The flow run notification policy id", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Updates an existing flow run notification policy.
"""
@@ -104,7 +104,7 @@ async def delete_flow_run_notification_policy(
..., description="The flow run notification policy id", alias="id"
),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a flow run notification policy by id.
"""
diff --git a/src/prefect/server/api/flow_run_states.py b/src/prefect/server/api/flow_run_states.py
index 72784f0175a1..438e531a9498 100644
--- a/src/prefect/server/api/flow_run_states.py
+++ b/src/prefect/server/api/flow_run_states.py
@@ -12,7 +12,9 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/flow_run_states", tags=["Flow Run States"])
+router: PrefectRouter = PrefectRouter(
+ prefix="/flow_run_states", tags=["Flow Run States"]
+)
@router.get("/{id}")
diff --git a/src/prefect/server/api/flow_runs.py b/src/prefect/server/api/flow_runs.py
index b3e36cf2a0b2..8ddfa9d2689f 100644
--- a/src/prefect/server/api/flow_runs.py
+++ b/src/prefect/server/api/flow_runs.py
@@ -5,7 +5,7 @@
import csv
import datetime
import io
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
import orjson
@@ -37,7 +37,10 @@
read_flow_run_graph,
)
from prefect.server.orchestration import dependencies as orchestration_dependencies
-from prefect.server.orchestration.policies import BaseOrchestrationPolicy
+from prefect.server.orchestration.policies import (
+ FlowRunOrchestrationPolicy,
+ TaskRunOrchestrationPolicy,
+)
from prefect.server.schemas.graph import Graph
from prefect.server.schemas.responses import (
FlowRunPaginationResponse,
@@ -47,9 +50,12 @@
from prefect.types import DateTime
from prefect.utilities import schema_tools
-logger = get_logger("server.api")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("server.api")
-router = PrefectRouter(prefix="/flow_runs", tags=["Flow Runs"])
+router: PrefectRouter = PrefectRouter(prefix="/flow_runs", tags=["Flow Runs"])
@router.post("/")
@@ -101,7 +107,7 @@ async def update_flow_run(
flow_run: schemas.actions.FlowRunUpdate,
flow_run_id: UUID = Path(..., description="The flow run id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Updates a flow run.
"""
@@ -349,18 +355,18 @@ async def read_flow_run_graph_v2(
async def resume_flow_run(
flow_run_id: UUID = Path(..., description="The flow run id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
- run_input: Optional[Dict] = Body(default=None, embed=True),
+ run_input: Optional[dict[str, Any]] = Body(default=None, embed=True),
response: Response = None,
- flow_policy: Type[BaseOrchestrationPolicy] = Depends(
+ flow_policy: type[FlowRunOrchestrationPolicy] = Depends(
orchestration_dependencies.provide_flow_policy
),
- task_policy: BaseOrchestrationPolicy = Depends(
+ task_policy: type[TaskRunOrchestrationPolicy] = Depends(
orchestration_dependencies.provide_task_policy
),
orchestration_parameters: Dict[str, Any] = Depends(
orchestration_dependencies.provide_flow_orchestration_parameters
),
- api_version=Depends(dependencies.provide_request_api_version),
+ api_version: str = Depends(dependencies.provide_request_api_version),
) -> OrchestrationResult:
"""
Resume a paused flow run.
@@ -539,7 +545,7 @@ async def read_flow_runs(
async def delete_flow_run(
flow_run_id: UUID = Path(..., description="The flow run id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a flow run by id.
"""
@@ -565,14 +571,14 @@ async def set_flow_run_state(
),
),
db: PrefectDBInterface = Depends(provide_database_interface),
- response: Response = None,
- flow_policy: Type[BaseOrchestrationPolicy] = Depends(
+ flow_policy: type[FlowRunOrchestrationPolicy] = Depends(
orchestration_dependencies.provide_flow_policy
),
orchestration_parameters: Dict[str, Any] = Depends(
orchestration_dependencies.provide_flow_orchestration_parameters
),
- api_version=Depends(dependencies.provide_request_api_version),
+ response: Response = None,
+ api_version: str = Depends(dependencies.provide_request_api_version),
) -> OrchestrationResult:
"""Set a flow run state, invoking any orchestration rules."""
@@ -611,7 +617,7 @@ async def create_flow_run_input(
value: bytes = Body(..., description="The value of the input"),
sender: Optional[str] = Body(None, description="The sender of the input"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Create a key/value input for a flow run.
"""
@@ -693,7 +699,7 @@ async def delete_flow_run_input(
flow_run_id: UUID = Path(..., description="The flow run id", alias="id"),
key: str = Path(..., description="The input key", alias="key"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a flow run input
"""
@@ -848,7 +854,7 @@ async def update_flow_run_labels(
flow_run_id: UUID = Path(..., description="The flow run id", alias="id"),
labels: Dict[str, Any] = Body(..., description="The labels to update"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Update the labels of a flow run.
"""
diff --git a/src/prefect/server/api/flows.py b/src/prefect/server/api/flows.py
index 4f7b5696c88d..a62adaacab5e 100644
--- a/src/prefect/server/api/flows.py
+++ b/src/prefect/server/api/flows.py
@@ -16,7 +16,7 @@
from prefect.server.schemas.responses import FlowPaginationResponse
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/flows", tags=["Flows"])
+router: PrefectRouter = PrefectRouter(prefix="/flows", tags=["Flows"])
@router.post("/")
@@ -46,7 +46,7 @@ async def update_flow(
flow: schemas.actions.FlowUpdate,
flow_id: UUID = Path(..., description="The flow id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Updates a flow.
"""
@@ -150,7 +150,7 @@ async def read_flows(
async def delete_flow(
flow_id: UUID = Path(..., description="The flow id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a flow by id.
"""
diff --git a/src/prefect/server/api/logs.py b/src/prefect/server/api/logs.py
index 18b723c1f91f..b473177fc3d5 100644
--- a/src/prefect/server/api/logs.py
+++ b/src/prefect/server/api/logs.py
@@ -12,14 +12,14 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/logs", tags=["Logs"])
+router: PrefectRouter = PrefectRouter(prefix="/logs", tags=["Logs"])
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_logs(
logs: List[schemas.actions.LogCreate],
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""Create new logs from the provided schema."""
for batch in models.logs.split_logs_into_batches(logs):
async with db.session_context(begin_transaction=True) as session:
diff --git a/src/prefect/server/api/root.py b/src/prefect/server/api/root.py
index 8a52153a4b8b..27420a12d563 100644
--- a/src/prefect/server/api/root.py
+++ b/src/prefect/server/api/root.py
@@ -8,11 +8,11 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="", tags=["Root"])
+router: PrefectRouter = PrefectRouter(prefix="", tags=["Root"])
@router.get("/hello")
-async def hello():
+async def hello() -> str:
"""Say hello!"""
return "👋"
@@ -20,7 +20,7 @@ async def hello():
@router.get("/ready")
async def perform_readiness_check(
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> JSONResponse:
is_db_connectable = await db.is_db_connectable()
if is_db_connectable:
diff --git a/src/prefect/server/api/run_history.py b/src/prefect/server/api/run_history.py
index b707b06ed4fb..85e7c3577861 100644
--- a/src/prefect/server/api/run_history.py
+++ b/src/prefect/server/api/run_history.py
@@ -4,7 +4,7 @@
import datetime
import json
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
import pydantic
import sqlalchemy as sa
@@ -16,7 +16,10 @@
from prefect.server.database import PrefectDBInterface, db_injector
from prefect.types import DateTime
-logger = get_logger("server.api")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("server.api")
@db_injector
diff --git a/src/prefect/server/api/saved_searches.py b/src/prefect/server/api/saved_searches.py
index fce131484590..786c8df03533 100644
--- a/src/prefect/server/api/saved_searches.py
+++ b/src/prefect/server/api/saved_searches.py
@@ -14,7 +14,7 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/saved_searches", tags=["SavedSearches"])
+router: PrefectRouter = PrefectRouter(prefix="/saved_searches", tags=["SavedSearches"])
@router.put("/")
@@ -85,7 +85,7 @@ async def read_saved_searches(
async def delete_saved_search(
saved_search_id: UUID = Path(..., description="The saved search id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a saved search by id.
"""
diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py
index 908ae5c3afc8..55f6b26f56e9 100644
--- a/src/prefect/server/api/server.py
+++ b/src/prefect/server/api/server.py
@@ -21,7 +21,7 @@
from contextlib import asynccontextmanager
from functools import partial, wraps
from hashlib import sha256
-from typing import Any, Awaitable, Callable, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional
import anyio
import asyncpg
@@ -66,17 +66,20 @@
)
from prefect.utilities.hashing import hash_objects
+if TYPE_CHECKING:
+ import logging
+
TITLE = "Prefect Server"
API_TITLE = "Prefect Prefect REST API"
UI_TITLE = "Prefect Prefect REST API UI"
-API_VERSION = prefect.__version__
+API_VERSION: str = prefect.__version__
# migrations should run only once per app start; the ephemeral API can potentially
# create multiple apps in a single process
LIFESPAN_RAN_FOR_APP: set[Any] = set()
-logger = get_logger("server")
+logger: "logging.Logger" = get_logger("server")
-enforce_minimum_version = EnforceMinimumAPIVersion(
+enforce_minimum_version: EnforceMinimumAPIVersion = EnforceMinimumAPIVersion(
# this should be <= SERVER_API_VERSION; clients that send
# a version header under this value will be rejected
minimum_api_version="0.8.0",
@@ -154,7 +157,9 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
await self.app(scope, receive, send)
-async def validation_exception_handler(request: Request, exc: RequestValidationError):
+async def validation_exception_handler(
+ request: Request, exc: RequestValidationError
+) -> JSONResponse:
"""Provide a detailed message for request validation errors."""
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -168,7 +173,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
)
-async def integrity_exception_handler(request: Request, exc: Exception):
+async def integrity_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Capture database integrity errors."""
logger.error("Encountered exception in request:", exc_info=True)
return JSONResponse(
@@ -183,7 +188,7 @@ async def integrity_exception_handler(request: Request, exc: Exception):
)
-def is_client_retryable_exception(exc: Exception):
+def is_client_retryable_exception(exc: Exception) -> bool:
if isinstance(exc, sqlalchemy.exc.OperationalError) and isinstance(
exc.orig, sqlite3.OperationalError
):
@@ -217,7 +222,7 @@ def replace_placeholder_string_in_files(
placeholder: str,
replacement: str,
allowed_extensions: list[str] | None = None,
-):
+) -> None:
"""
Recursively loops through all files in the given directory and replaces
a placeholder string.
@@ -253,7 +258,9 @@ def copy_directory(directory: str, path: str) -> None:
shutil.copy2(source, destination)
-async def custom_internal_exception_handler(request: Request, exc: Exception):
+async def custom_internal_exception_handler(
+ request: Request, exc: Exception
+) -> JSONResponse:
"""
Log a detailed exception for internal server errors before returning.
@@ -278,7 +285,7 @@ async def custom_internal_exception_handler(request: Request, exc: Exception):
async def prefect_object_not_found_exception_handler(
request: Request, exc: ObjectNotFoundError
-):
+) -> JSONResponse:
"""Return 404 status code on object not found exceptions."""
return JSONResponse(
content={"exception_message": str(exc)}, status_code=status.HTTP_404_NOT_FOUND
@@ -353,6 +360,9 @@ async def server_version() -> str: # type: ignore[reportUnusedFunction]
async def token_validation(request: Request, call_next: Any): # type: ignore[reportUnusedFunction]
header_token = request.headers.get("Authorization")
+ # used for probes in k8s and such
+ if request.url.path in ["/api/health", "/api/ready"]:
+ return await call_next(request)
try:
if header_token is None:
return JSONResponse(
@@ -787,7 +797,7 @@ def openapi():
return app
-subprocess_server_logger = get_logger()
+subprocess_server_logger: "logging.Logger" = get_logger()
class SubprocessASGIServer:
@@ -808,12 +818,11 @@ def __init__(self, port: Optional[int] = None):
# This ensures initialization happens only once
if not hasattr(self, "_initialized"):
self.port: Optional[int] = port
- self.server_process = None
- self.server = None
- self.running = False
+ self.server_process: subprocess.Popen[Any] | None = None
+ self.running: bool = False
self._initialized = True
- def find_available_port(self):
+ def find_available_port(self) -> int:
max_attempts = 10
for _ in range(max_attempts):
port = random.choice(self._port_range)
@@ -823,7 +832,7 @@ def find_available_port(self):
raise RuntimeError("Unable to find an available port after multiple attempts")
@staticmethod
- def is_port_available(port: int):
+ def is_port_available(port: int) -> bool:
with contextlib.closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
@@ -841,7 +850,7 @@ def address(self) -> str:
def api_url(self) -> str:
return f"{self.address}/api"
- def start(self, timeout: Optional[int] = None):
+ def start(self, timeout: Optional[int] = None) -> None:
"""
Start the server in a separate process. Safe to call multiple times; only starts
the server once.
@@ -935,7 +944,7 @@ def _run_uvicorn_command(self) -> subprocess.Popen[Any]:
},
)
- def stop(self):
+ def stop(self) -> None:
if self.server_process:
subprocess_server_logger.info(
f"Stopping temporary server on {self.address}"
diff --git a/src/prefect/server/api/task_run_states.py b/src/prefect/server/api/task_run_states.py
index ef68e8b4d7a0..9722a1d3d93b 100644
--- a/src/prefect/server/api/task_run_states.py
+++ b/src/prefect/server/api/task_run_states.py
@@ -12,7 +12,9 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/task_run_states", tags=["Task Run States"])
+router: PrefectRouter = PrefectRouter(
+ prefix="/task_run_states", tags=["Task Run States"]
+)
@router.get("/{id}")
diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py
index 69d5e14adbd6..316eb2551e5d 100644
--- a/src/prefect/server/api/task_runs.py
+++ b/src/prefect/server/api/task_runs.py
@@ -4,7 +4,7 @@
import asyncio
import datetime
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
import pendulum
@@ -27,17 +27,19 @@
from prefect.server.database import PrefectDBInterface, provide_database_interface
from prefect.server.orchestration import dependencies as orchestration_dependencies
from prefect.server.orchestration.core_policy import CoreTaskPolicy
-from prefect.server.orchestration.policies import BaseOrchestrationPolicy
+from prefect.server.orchestration.policies import TaskRunOrchestrationPolicy
from prefect.server.schemas.responses import OrchestrationResult
from prefect.server.task_queue import MultiQueue, TaskQueue
from prefect.server.utilities import subscriptions
from prefect.server.utilities.server import PrefectRouter
from prefect.types import DateTime
-logger = get_logger("server.api")
+if TYPE_CHECKING:
+ import logging
+logger: "logging.Logger" = get_logger("server.api")
-router = PrefectRouter(prefix="/task_runs", tags=["Task Runs"])
+router: PrefectRouter = PrefectRouter(prefix="/task_runs", tags=["Task Runs"])
@router.post("/")
@@ -87,7 +89,7 @@ async def update_task_run(
task_run: schemas.actions.TaskRunUpdate,
task_run_id: UUID = Path(..., description="The task run id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Updates a task run.
"""
@@ -214,7 +216,7 @@ async def read_task_runs(
async def delete_task_run(
task_run_id: UUID = Path(..., description="The task run id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a task run by id.
"""
@@ -239,7 +241,7 @@ async def set_task_run_state(
),
db: PrefectDBInterface = Depends(provide_database_interface),
response: Response = None,
- task_policy: BaseOrchestrationPolicy = Depends(
+ task_policy: TaskRunOrchestrationPolicy = Depends(
orchestration_dependencies.provide_task_policy
),
orchestration_parameters: Dict[str, Any] = Depends(
@@ -275,7 +277,7 @@ async def set_task_run_state(
@router.websocket("/subscriptions/scheduled")
-async def scheduled_task_subscription(websocket: WebSocket):
+async def scheduled_task_subscription(websocket: WebSocket) -> None:
websocket = await subscriptions.accept_prefect_socket(websocket)
if not websocket:
return
diff --git a/src/prefect/server/api/task_workers.py b/src/prefect/server/api/task_workers.py
index b3ebc3edb6ff..cc8d30a8cea9 100644
--- a/src/prefect/server/api/task_workers.py
+++ b/src/prefect/server/api/task_workers.py
@@ -7,7 +7,7 @@
from prefect.server.models.task_workers import TaskWorkerResponse
from prefect.server.utilities.server import PrefectRouter
-router = PrefectRouter(prefix="/task_workers", tags=["Task Workers"])
+router: PrefectRouter = PrefectRouter(prefix="/task_workers", tags=["Task Workers"])
class TaskWorkerFilter(BaseModel):
diff --git a/src/prefect/server/api/templates.py b/src/prefect/server/api/templates.py
index 29b954aac56e..55c927eb962f 100644
--- a/src/prefect/server/api/templates.py
+++ b/src/prefect/server/api/templates.py
@@ -8,7 +8,7 @@
validate_user_template,
)
-router = PrefectRouter(prefix="/templates", tags=["Automations"])
+router: PrefectRouter = PrefectRouter(prefix="/templates", tags=["Automations"])
@router.post(
diff --git a/src/prefect/server/api/ui/flow_runs.py b/src/prefect/server/api/ui/flow_runs.py
index a67999f92919..f443c4dcc8fc 100644
--- a/src/prefect/server/api/ui/flow_runs.py
+++ b/src/prefect/server/api/ui/flow_runs.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
import datetime
-from typing import List
+from typing import TYPE_CHECKING, List
from uuid import UUID
import sqlalchemy as sa
@@ -14,9 +16,12 @@
from prefect.server.utilities.server import PrefectRouter
from prefect.types import DateTime
-logger = get_logger("server.api.ui.flow_runs")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("server.api.ui.flow_runs")
-router = PrefectRouter(prefix="/ui/flow_runs", tags=["Flow Runs", "UI"])
+router: PrefectRouter = PrefectRouter(prefix="/ui/flow_runs", tags=["Flow Runs", "UI"])
class SimpleFlowRun(PrefectBaseModel):
diff --git a/src/prefect/server/api/ui/flows.py b/src/prefect/server/api/ui/flows.py
index 8c0b3375d480..e78dad89c674 100644
--- a/src/prefect/server/api/ui/flows.py
+++ b/src/prefect/server/api/ui/flows.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
from datetime import datetime
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from uuid import UUID
import pendulum
@@ -15,9 +17,12 @@
from prefect.server.utilities.server import PrefectRouter
from prefect.types import DateTime
-logger = get_logger()
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger()
-router = PrefectRouter(prefix="/ui/flows", tags=["Flows", "UI"])
+router: PrefectRouter = PrefectRouter(prefix="/ui/flows", tags=["Flows", "UI"])
class SimpleNextFlowRun(PrefectBaseModel):
@@ -32,7 +37,9 @@ class SimpleNextFlowRun(PrefectBaseModel):
@field_validator("next_scheduled_start_time", mode="before")
@classmethod
- def validate_next_scheduled_start_time(cls, v):
+ def validate_next_scheduled_start_time(
+ cls, v: pendulum.DateTime | datetime
+ ) -> pendulum.DateTime:
if isinstance(v, datetime):
return pendulum.instance(v)
return v
diff --git a/src/prefect/server/api/ui/schemas.py b/src/prefect/server/api/ui/schemas.py
index 768b9c7231d2..91a0bbeb472a 100644
--- a/src/prefect/server/api/ui/schemas.py
+++ b/src/prefect/server/api/ui/schemas.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict
from fastapi import Body, Depends, HTTPException, status
@@ -14,9 +14,12 @@
validate,
)
-router = APIRouter(prefix="/ui/schemas", tags=["UI", "Schemas"])
+if TYPE_CHECKING:
+ import logging
-logger = get_logger("server.api.ui.schemas")
+router: APIRouter = APIRouter(prefix="/ui/schemas", tags=["UI", "Schemas"])
+
+logger: "logging.Logger" = get_logger("server.api.ui.schemas")
@router.post("/validate")
@@ -24,7 +27,7 @@ async def validate_obj(
json_schema: Dict[str, Any] = Body(..., embed=True, alias="schema"),
values: Dict[str, Any] = Body(..., embed=True),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> dict[str, Any]:
schema = preprocess_schema(json_schema)
try:
diff --git a/src/prefect/server/api/ui/task_runs.py b/src/prefect/server/api/ui/task_runs.py
index 938d2a27d57e..6b2de8c9ba5c 100644
--- a/src/prefect/server/api/ui/task_runs.py
+++ b/src/prefect/server/api/ui/task_runs.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import List, Optional, cast
+from typing import TYPE_CHECKING, List, Optional, cast
import pendulum
import sqlalchemy as sa
@@ -14,9 +14,12 @@
from prefect.server.utilities.server import PrefectRouter
from prefect.types import DateTime
-logger = get_logger("orion.api.ui.task_runs")
+if TYPE_CHECKING:
+ import logging
-router = PrefectRouter(prefix="/ui/task_runs", tags=["Task Runs", "UI"])
+logger: "logging.Logger" = get_logger("server.api.ui.task_runs")
+
+router: PrefectRouter = PrefectRouter(prefix="/ui/task_runs", tags=["Task Runs", "UI"])
FAILED_STATES = [schemas.states.StateType.CRASHED, schemas.states.StateType.FAILED]
@@ -28,7 +31,7 @@ class TaskRunCount(PrefectBaseModel):
failed: int = Field(default=..., description="The number of failed task runs.")
@model_serializer
- def ser_model(self) -> dict:
+ def ser_model(self) -> dict[str, int]:
return {
"completed": int(self.completed),
"failed": int(self.failed),
diff --git a/src/prefect/server/api/validation.py b/src/prefect/server/api/validation.py
index ebc085a32e25..3c1905ca571d 100644
--- a/src/prefect/server/api/validation.py
+++ b/src/prefect/server/api/validation.py
@@ -35,7 +35,7 @@
allow None values, the Pydantic model will fail to validate at runtime.
"""
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from uuid import UUID
import pydantic
@@ -50,7 +50,10 @@
from prefect.server.schemas.core import WorkPool
from prefect.utilities.schema_tools import ValidationError, is_valid_schema, validate
-logger = get_logger("server.api.validation")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("server.api.validation")
DeploymentAction = Union[
schemas.actions.DeploymentCreate, schemas.actions.DeploymentUpdate
diff --git a/src/prefect/server/api/variables.py b/src/prefect/server/api/variables.py
index b05cb8c648e2..e7e5785d3615 100644
--- a/src/prefect/server/api/variables.py
+++ b/src/prefect/server/api/variables.py
@@ -11,12 +11,18 @@
from prefect.server import models
from prefect.server.api.dependencies import LimitBody
-from prefect.server.database import PrefectDBInterface, provide_database_interface
+from prefect.server.database import (
+ PrefectDBInterface,
+ orm_models,
+ provide_database_interface,
+)
from prefect.server.schemas import actions, core, filters, sorting
from prefect.server.utilities.server import PrefectRouter
-async def get_variable_or_404(session: AsyncSession, variable_id: UUID):
+async def get_variable_or_404(
+ session: AsyncSession, variable_id: UUID
+) -> orm_models.Variable:
"""Returns a variable or raises 404 HTTPException if it does not exist"""
variable = await models.variables.read_variable(
@@ -28,7 +34,9 @@ async def get_variable_or_404(session: AsyncSession, variable_id: UUID):
return variable
-async def get_variable_by_name_or_404(session: AsyncSession, name: str):
+async def get_variable_by_name_or_404(
+ session: AsyncSession, name: str
+) -> orm_models.Variable:
"""Returns a variable or raises 404 HTTPException if it does not exist"""
variable = await models.variables.read_variable_by_name(session=session, name=name)
@@ -38,7 +46,7 @@ async def get_variable_by_name_or_404(session: AsyncSession, name: str):
return variable
-router = PrefectRouter(
+router: PrefectRouter = PrefectRouter(
prefix="/variables",
tags=["Variables"],
)
@@ -120,7 +128,7 @@ async def update_variable(
variable: actions.VariableUpdate,
variable_id: UUID = Path(..., alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
updated = await models.variables.update_variable(
session=session,
@@ -136,7 +144,7 @@ async def update_variable_by_name(
variable: actions.VariableUpdate,
name: str = Path(..., alias="name"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
updated = await models.variables.update_variable_by_name(
session=session,
@@ -151,7 +159,7 @@ async def update_variable_by_name(
async def delete_variable(
variable_id: UUID = Path(..., alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
deleted = await models.variables.delete_variable(
session=session, variable_id=variable_id
@@ -164,7 +172,7 @@ async def delete_variable(
async def delete_variable_by_name(
name: str = Path(...),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
deleted = await models.variables.delete_variable_by_name(
session=session, name=name
diff --git a/src/prefect/server/api/work_queues.py b/src/prefect/server/api/work_queues.py
index ad9e17df1645..1267fe7ae87a 100644
--- a/src/prefect/server/api/work_queues.py
+++ b/src/prefect/server/api/work_queues.py
@@ -33,7 +33,7 @@
from prefect.server.utilities.server import PrefectRouter
from prefect.types import DateTime
-router = PrefectRouter(prefix="/work_queues", tags=["Work Queues"])
+router: PrefectRouter = PrefectRouter(prefix="/work_queues", tags=["Work Queues"])
@router.post("/", status_code=status.HTTP_201_CREATED)
@@ -69,7 +69,7 @@ async def update_work_queue(
work_queue: schemas.actions.WorkQueueUpdate,
work_queue_id: UUID = Path(..., description="The work queue id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Updates an existing work queue.
"""
@@ -229,7 +229,7 @@ async def read_work_queues(
async def delete_work_queue(
work_queue_id: UUID = Path(..., description="The work queue id", alias="id"),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a work queue by id.
"""
diff --git a/src/prefect/server/api/workers.py b/src/prefect/server/api/workers.py
index 0f0d8546dd19..7d46d16bec42 100644
--- a/src/prefect/server/api/workers.py
+++ b/src/prefect/server/api/workers.py
@@ -35,7 +35,7 @@
if TYPE_CHECKING:
from prefect.server.database.orm_models import ORMWorkQueue
-router = PrefectRouter(
+router: PrefectRouter = PrefectRouter(
prefix="/work_pools",
tags=["Work Pools"],
)
@@ -257,7 +257,7 @@ async def update_work_pool(
work_pool_name: str = Path(..., description="The work pool name", alias="name"),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Update a work pool
"""
@@ -292,7 +292,7 @@ async def delete_work_pool(
work_pool_name: str = Path(..., description="The work pool name", alias="name"),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a work pool
"""
@@ -505,7 +505,7 @@ async def update_work_queue(
),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Update a work pool queue
"""
@@ -535,7 +535,7 @@ async def delete_work_queue(
),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a work pool queue
"""
@@ -573,7 +573,7 @@ async def worker_heartbeat(
),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
async with db.session_context(begin_transaction=True) as session:
work_pool = await models.workers.read_work_pool_by_name(
session=session,
@@ -638,7 +638,7 @@ async def delete_worker(
),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: PrefectDBInterface = Depends(provide_database_interface),
-):
+) -> None:
"""
Delete a work pool's worker
"""
diff --git a/src/prefect/server/models/block_registration.py b/src/prefect/server/models/block_registration.py
index f537f75966b7..76cdb66b897d 100644
--- a/src/prefect/server/models/block_registration.py
+++ b/src/prefect/server/models/block_registration.py
@@ -13,10 +13,12 @@
from prefect.server import models, schemas
if TYPE_CHECKING:
+ import logging
+
from prefect.client.schemas import BlockSchema as ClientBlockSchema
from prefect.client.schemas import BlockType as ClientBlockType
-logger = get_logger("server")
+logger: "logging.Logger" = get_logger("server")
COLLECTIONS_BLOCKS_DATA_PATH = (
Path(__file__).parent.parent / "collection_blocks_data.json"
diff --git a/src/prefect/server/models/block_schemas.py b/src/prefect/server/models/block_schemas.py
index 264dc4913fc8..dee1cb34f4c9 100644
--- a/src/prefect/server/models/block_schemas.py
+++ b/src/prefect/server/models/block_schemas.py
@@ -40,7 +40,7 @@ async def create_block_schema(
"ClientBlockSchema",
],
override: bool = False,
- definitions: Optional[Dict] = None,
+ definitions: Optional[dict[str, Any]] = None,
) -> Union[BlockSchema, orm_models.BlockSchema]:
"""
Create a new block schema.
diff --git a/src/prefect/server/models/flow_runs.py b/src/prefect/server/models/flow_runs.py
index fca96360dbfa..d5543d8d98d1 100644
--- a/src/prefect/server/models/flow_runs.py
+++ b/src/prefect/server/models/flow_runs.py
@@ -7,6 +7,7 @@
import datetime
from itertools import chain
from typing import (
+ TYPE_CHECKING,
Any,
Dict,
List,
@@ -49,13 +50,13 @@
)
from prefect.types import KeyValueLabels
-logger = get_logger("flow_runs")
+if TYPE_CHECKING:
+ import logging
+logger: "logging.Logger" = get_logger("flow_runs")
-logger = get_logger("flow_runs")
-
-T = TypeVar("T", bound=tuple)
+T = TypeVar("T", bound=tuple[Any, ...])
@db_injector
@@ -281,7 +282,7 @@ async def _apply_flow_run_filters(
async def read_flow_runs(
db: PrefectDBInterface,
session: AsyncSession,
- columns: Optional[List] = None,
+ columns: Optional[list[str]] = None,
flow_filter: Optional[schemas.filters.FlowFilter] = None,
flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
@@ -344,7 +345,7 @@ async def read_flow_runs(
async def cleanup_flow_run_concurrency_slots(
session: AsyncSession,
flow_run: orm_models.FlowRun,
-):
+) -> None:
"""
Cleanup flow run related resources, such as releasing concurrency slots.
All operations should be idempotent and safe to call multiple times.
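
This module also tightens `T = TypeVar("T", bound=tuple)` to `bound=tuple[Any, ...]` and drops a duplicated `logger` assignment. The bound change matters to the checker: a bare `tuple` bound has unknown element types, while `tuple[Any, ...]` is fully declared with identical runtime behavior. A small sketch:

from typing import Any, TypeVar

T = TypeVar("T", bound=tuple[Any, ...])


def first_row(rows: list[T]) -> T:
    """The concrete tuple type of the rows is preserved for the caller."""
    return rows[0]


pair: tuple[int, str] = first_row([(1, "a"), (2, "b")])
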
diff --git a/src/prefect/server/models/logs.py b/src/prefect/server/models/logs.py
index ce5eac18eace..5dda835c9de9 100644
--- a/src/prefect/server/models/logs.py
+++ b/src/prefect/server/models/logs.py
@@ -3,7 +3,7 @@
Intended for internal use by the Prefect REST API.
"""
-from typing import Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Generator, List, Optional, Sequence, Tuple
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -23,7 +23,10 @@
# ...so we can only INSERT batches of a certain size at a time
LOG_BATCH_SIZE = MAXIMUM_QUERY_PARAMETERS // NUMBER_OF_LOG_FIELDS
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
def split_logs_into_batches(
diff --git a/src/prefect/server/models/task_runs.py b/src/prefect/server/models/task_runs.py
index c821b6c4b43b..ae7f54b7889f 100644
--- a/src/prefect/server/models/task_runs.py
+++ b/src/prefect/server/models/task_runs.py
@@ -4,7 +4,17 @@
"""
import contextlib
-from typing import Any, Dict, Optional, Sequence, Type, TypeVar, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Optional,
+ Sequence,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
from uuid import UUID
import pendulum
@@ -29,9 +39,12 @@
from prefect.server.orchestration.rules import TaskOrchestrationContext
from prefect.server.schemas.responses import OrchestrationResult
-T = TypeVar("T", bound=tuple)
+if TYPE_CHECKING:
+ import logging
+
+T = TypeVar("T", bound=tuple[Any, ...])
-logger = get_logger("server")
+logger: "logging.Logger" = get_logger("server")
@db_injector
diff --git a/src/prefect/server/models/task_workers.py b/src/prefect/server/models/task_workers.py
index 899eba98b4d9..4a285feee54f 100644
--- a/src/prefect/server/models/task_workers.py
+++ b/src/prefect/server/models/task_workers.py
@@ -77,7 +77,7 @@ def reset(self) -> None:
# Global instance of the task worker tracker
-task_worker_tracker = InMemoryTaskWorkerTracker()
+task_worker_tracker: InMemoryTaskWorkerTracker = InMemoryTaskWorkerTracker()
# Main utilities to be used in the API layer
diff --git a/src/prefect/server/schemas/actions.py b/src/prefect/server/schemas/actions.py
index 30746fc529f9..1fd1c20b5ec7 100644
--- a/src/prefect/server/schemas/actions.py
+++ b/src/prefect/server/schemas/actions.py
@@ -2,6 +2,8 @@
Reduced schemas for accepting API actions.
"""
+from __future__ import annotations
+
import json
from copy import deepcopy
from typing import Any, ClassVar, Dict, List, Optional, Union
@@ -44,24 +46,24 @@
from prefect.utilities.templating import find_placeholders
-def validate_block_type_slug(value):
+def validate_block_type_slug(value: str) -> str:
raise_on_name_alphanumeric_dashes_only(value, field_name="Block type slug")
return value
-def validate_block_document_name(value):
+def validate_block_document_name(value: str | None) -> str | None:
if value is not None:
raise_on_name_alphanumeric_dashes_only(value, field_name="Block document name")
return value
-def validate_artifact_key(value):
+def validate_artifact_key(value: str | None) -> str | None:
if value is not None:
raise_on_name_alphanumeric_dashes_only(value, field_name="Artifact key")
return value
-def validate_variable_name(value):
+def validate_variable_name(value: str) -> str:
raise_on_name_alphanumeric_underscores_only(value, field_name="Variable name")
return value
@@ -112,7 +114,9 @@ class DeploymentScheduleCreate(ActionBaseModel):
@field_validator("max_scheduled_runs")
@classmethod
- def validate_max_scheduled_runs(cls, v):
+ def validate_max_scheduled_runs(
+ cls, v: PositiveInteger | None
+ ) -> PositiveInteger | None:
return validate_schedule_max_scheduled_runs(
v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value()
)
@@ -133,7 +137,9 @@ class DeploymentScheduleUpdate(ActionBaseModel):
@field_validator("max_scheduled_runs")
@classmethod
- def validate_max_scheduled_runs(cls, v):
+ def validate_max_scheduled_runs(
+ cls, v: PositiveInteger | None
+ ) -> PositiveInteger | None:
return validate_schedule_max_scheduled_runs(
v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value()
)
@@ -187,7 +193,7 @@ class DeploymentCreate(ActionBaseModel):
description="A dictionary of key-value labels. Values can be strings, numbers, or booleans.",
examples=[{"key": "value1", "key2": 42}],
)
- pull_steps: Optional[List[dict]] = Field(None)
+ pull_steps: Optional[List[dict[str, Any]]] = Field(None)
work_queue_name: Optional[str] = Field(None)
work_pool_name: Optional[str] = Field(
@@ -206,7 +212,7 @@ class DeploymentCreate(ActionBaseModel):
description="Overrides for the flow's infrastructure configuration.",
)
- def check_valid_configuration(self, base_job_template: dict):
+ def check_valid_configuration(self, base_job_template: dict[str, Any]) -> None:
"""
Check that the combination of base_job_template defaults and job_variables
conforms to the specified schema.
@@ -232,11 +238,13 @@ def check_valid_configuration(self, base_job_template: dict):
@model_validator(mode="before")
@classmethod
- def remove_old_fields(cls, values):
+ def remove_old_fields(cls, values: dict[str, Any]) -> dict[str, Any]:
return remove_old_deployment_fields(values)
@model_validator(mode="before")
- def _validate_parameters_conform_to_schema(cls, values):
+ def _validate_parameters_conform_to_schema(
+ cls, values: dict[str, Any]
+ ) -> dict[str, Any]:
values["parameters"] = validate_parameters_conform_to_schema(
values.get("parameters", {}), values
)
@@ -251,7 +259,7 @@ class DeploymentUpdate(ActionBaseModel):
@model_validator(mode="before")
@classmethod
- def remove_old_fields(cls, values):
+ def remove_old_fields(cls, values: dict[str, Any]) -> dict[str, Any]:
return remove_old_deployment_fields(values)
version: Optional[str] = Field(None)
@@ -300,7 +308,7 @@ def remove_old_fields(cls, values):
)
model_config: ClassVar[ConfigDict] = ConfigDict(populate_by_name=True)
- def check_valid_configuration(self, base_job_template: dict):
+ def check_valid_configuration(self, base_job_template: dict[str, Any]) -> None:
"""
Check that the combination of base_job_template defaults and job_variables
conforms to the schema specified in the base_job_template.
@@ -315,7 +323,7 @@ def check_valid_configuration(self, base_job_template: dict):
variables_schema = deepcopy(base_job_template.get("variables"))
- if variables_schema is not None:
+ if variables_schema is not None and self.job_variables is not None:
errors = validate(
self.job_variables,
variables_schema,
@@ -343,7 +351,7 @@ class FlowRunUpdate(ActionBaseModel):
@field_validator("name", mode="before")
@classmethod
- def set_name(cls, name):
+ def set_name(cls, name: str) -> str:
return get_or_create_run_name(name)
@@ -459,12 +467,12 @@ class TaskRunCreate(ActionBaseModel):
@field_validator("name", mode="before")
@classmethod
- def set_name(cls, name):
+ def set_name(cls, name: str) -> str:
return get_or_create_run_name(name)
@field_validator("cache_key")
@classmethod
- def validate_cache_key(cls, cache_key):
+ def validate_cache_key(cls, cache_key: str | None) -> str | None:
return validate_cache_key_length(cache_key)
@@ -477,7 +485,7 @@ class TaskRunUpdate(ActionBaseModel):
@field_validator("name", mode="before")
@classmethod
- def set_name(cls, name):
+ def set_name(cls, name: str) -> str:
return get_or_create_run_name(name)
@@ -544,7 +552,7 @@ class FlowRunCreate(ActionBaseModel):
@field_validator("name", mode="before")
@classmethod
- def set_name(cls, name):
+ def set_name(cls, name: str) -> str:
return get_or_create_run_name(name)
@@ -683,7 +691,7 @@ class BlockTypeUpdate(ActionBaseModel):
code_example: Optional[str] = Field(None)
@classmethod
- def updatable_fields(cls) -> set:
+ def updatable_fields(cls) -> set[str]:
return get_class_fields_only(cls)
@@ -780,7 +788,7 @@ class LogCreate(ActionBaseModel):
task_run_id: Optional[UUID] = Field(None)
-def validate_base_job_template(v):
+def validate_base_job_template(v: dict[str, Any]) -> dict[str, Any]:
if v == dict():
return v
@@ -791,7 +799,7 @@ def validate_base_job_template(v):
"The `base_job_template` must contain both a `job_configuration` key"
" and a `variables` key."
)
- template_variables = set()
+ template_variables: set[str] = set()
for template in job_config.values():
# find any variables inside of double curly braces, minus any whitespace
# e.g. "{{ var1 }}.{{var2}}" -> ["var1", "var2"]
@@ -927,7 +935,7 @@ class FlowRunNotificationPolicyCreate(ActionBaseModel):
@field_validator("message_template")
@classmethod
- def validate_message_template_variables(cls, v):
+ def validate_message_template_variables(cls, v: str | None) -> str | None:
return validate_message_template_variables(v)
@@ -942,7 +950,7 @@ class FlowRunNotificationPolicyUpdate(ActionBaseModel):
@field_validator("message_template")
@classmethod
- def validate_message_template_variables(cls, v):
+ def validate_message_template_variables(cls, v: str | None) -> str | None:
return validate_message_template_variables(v)
@@ -984,8 +992,8 @@ class ArtifactCreate(ActionBaseModel):
)
@classmethod
- def from_result(cls, data: Any):
- artifact_info = dict()
+ def from_result(cls, data: Any | dict[str, Any]) -> "ArtifactCreate":
+ artifact_info: dict[str, Any] = dict()
if isinstance(data, dict):
artifact_key = data.pop("artifact_key", None)
if artifact_key:
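To make the convention in this file concrete: here is a minimal, stand-alone sketch (not a Prefect model) of the typed-validator pattern this diff applies across the schema actions. Pydantic v2 "before" validators see the raw input mapping, so `dict[str, Any]` in and out is the natural annotation.

```python
from typing import Any

from pydantic import BaseModel, field_validator, model_validator


class Run(BaseModel):
    name: str

    @model_validator(mode="before")
    @classmethod
    def fill_defaults(cls, values: dict[str, Any]) -> dict[str, Any]:
        # "before" model validators receive the raw input mapping
        values.setdefault("name", "unnamed")
        return values

    @field_validator("name", mode="before")
    @classmethod
    def strip_name(cls, name: str) -> str:
        return name.strip()
```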
diff --git a/src/prefect/server/schemas/core.py b/src/prefect/server/schemas/core.py
index 7b245446f335..029455c668d8 100644
--- a/src/prefect/server/schemas/core.py
+++ b/src/prefect/server/schemas/core.py
@@ -559,7 +559,7 @@ class DeploymentSchedule(ORMBaseModel):
@field_validator("max_scheduled_runs")
@classmethod
- def validate_max_scheduled_runs(cls, v):
+ def validate_max_scheduled_runs(cls, v: int) -> int:
return validate_schedule_max_scheduled_runs(
v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value()
)
@@ -1069,7 +1069,7 @@ class FlowRunNotificationPolicy(ORMBaseModel):
@field_validator("message_template")
@classmethod
- def validate_message_template_variables(cls, v):
+ def validate_message_template_variables(cls, v: str) -> str:
return validate_message_template_variables(v)
@@ -1187,7 +1187,7 @@ class Artifact(ORMBaseModel):
" the artifact type."
),
)
- metadata_: Optional[Dict[str, str]] = Field(
+ metadata_: Optional[dict[str, str]] = Field(
default=None,
description=(
"User-defined artifact metadata. Content must be string key and value"
@@ -1202,8 +1202,8 @@ class Artifact(ORMBaseModel):
)
@classmethod
- def from_result(cls, data: Any):
- artifact_info = dict()
+ def from_result(cls, data: Any | dict[str, Any]) -> "Artifact":
+ artifact_info: dict[str, Any] = dict()
if isinstance(data, dict):
artifact_key = data.pop("artifact_key", None)
if artifact_key:
@@ -1221,7 +1221,7 @@ def from_result(cls, data: Any):
@field_validator("metadata_")
@classmethod
- def validate_metadata_length(cls, v):
+ def validate_metadata_length(cls, v: dict[str, str]) -> dict[str, str]:
return validate_max_metadata_length(v)
@@ -1289,7 +1289,7 @@ class FlowRunInput(ORMBaseModel):
@field_validator("key", check_fields=False)
@classmethod
- def validate_name_characters(cls, v):
+ def validate_name_characters(cls, v: str) -> str:
raise_on_name_alphanumeric_dashes_only(v)
return v
diff --git a/src/prefect/server/schemas/filters.py b/src/prefect/server/schemas/filters.py
index 2b1c63452b08..445dc31d495b 100644
--- a/src/prefect/server/schemas/filters.py
+++ b/src/prefect/server/schemas/filters.py
@@ -676,8 +676,8 @@ class FlowRunFilter(PrefectOperatorFilterBaseModel):
default=None, description="Filter criteria for `FlowRun.idempotency_key`"
)
- def only_filters_on_id(self):
- return (
+ def only_filters_on_id(self) -> bool:
+ return bool(
self.id is not None
and (self.id.any_ and not self.id.not_any_)
and self.name is None
diff --git a/src/prefect/server/schemas/responses.py b/src/prefect/server/schemas/responses.py
index 4df1ab929fed..80b74f2e1b76 100644
--- a/src/prefect/server/schemas/responses.py
+++ b/src/prefect/server/schemas/responses.py
@@ -410,7 +410,7 @@ class DeploymentResponse(ORMBaseModel):
" storage or an absolute path."
),
)
- pull_steps: Optional[List[dict]] = Field(
+ pull_steps: Optional[list[dict[str, Any]]] = Field(
default=None, description="Pull steps for cloning and running this deployment."
)
entrypoint: Optional[str] = Field(
diff --git a/src/prefect/server/schemas/schedules.py b/src/prefect/server/schemas/schedules.py
index 0466643f784f..438c07d8b6db 100644
--- a/src/prefect/server/schemas/schedules.py
+++ b/src/prefect/server/schemas/schedules.py
@@ -1,6 +1,7 @@
"""
Schedule schemas
"""
+from __future__ import annotations
import datetime
from typing import Annotated, Any, ClassVar, Generator, List, Optional, Tuple, Union
@@ -225,7 +226,7 @@ def validate_timezone(self):
@field_validator("cron")
@classmethod
- def valid_cron_string(cls, v):
+ def valid_cron_string(cls, v: str) -> str:
return validate_cron_string(v)
async def get_dates(
@@ -368,11 +369,13 @@ class RRuleSchedule(PrefectBaseModel):
@field_validator("rrule")
@classmethod
- def validate_rrule_str(cls, v):
+ def validate_rrule_str(cls, v: str) -> str:
return validate_rrule_string(v)
@classmethod
- def from_rrule(cls, rrule: dateutil.rrule.rrule):
+ def from_rrule(
+ cls, rrule: dateutil.rrule.rrule | dateutil.rrule.rruleset
+ ) -> "RRuleSchedule":
if isinstance(rrule, dateutil.rrule.rrule):
if rrule._dtstart.tzinfo is not None:
timezone = rrule._dtstart.tzinfo.name
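A small usage sketch for the widened `from_rrule` signature, assuming the method body handles both input types; only the construction below is standard `dateutil`.

```python
import datetime

from dateutil import rrule as dr

daily = dr.rrule(dr.DAILY, dtstart=datetime.datetime(2025, 1, 1))

combined = dr.rruleset()
combined.rrule(daily)
combined.rrule(dr.rrule(dr.WEEKLY, dtstart=datetime.datetime(2025, 1, 1)))

# Both calls now type-check:
# RRuleSchedule.from_rrule(daily)
# RRuleSchedule.from_rrule(combined)
```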
diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py
index 2ac202c70d4f..d0ac008f04f0 100644
--- a/src/prefect/server/schemas/states.py
+++ b/src/prefect/server/schemas/states.py
@@ -63,7 +63,7 @@ class CountByState(PrefectBaseModel):
@field_validator("*")
@classmethod
- def check_key(cls, value: Optional[Any], info: ValidationInfo):
+ def check_key(cls, value: Optional[Any], info: ValidationInfo) -> Optional[Any]:
if info.field_name not in StateType.__members__:
raise ValueError(f"{info.field_name} is not a valid StateType")
return value
diff --git a/src/prefect/server/task_queue.py b/src/prefect/server/task_queue.py
index f63de8607f80..5854eda8ff8d 100644
--- a/src/prefect/server/task_queue.py
+++ b/src/prefect/server/task_queue.py
@@ -38,7 +38,7 @@ def configure_task_key(
task_key: str,
scheduled_size: Optional[int] = None,
retry_size: Optional[int] = None,
- ):
+ ) -> None:
scheduled_size = scheduled_size or cls.default_scheduled_max_size
retry_size = retry_size or cls.default_retry_max_size
cls._queue_size_configs[task_key] = (scheduled_size, retry_size)
diff --git a/src/prefect/server/utilities/database.py b/src/prefect/server/utilities/database.py
index 80ff10db658e..d133037f7bad 100644
--- a/src/prefect/server/utilities/database.py
+++ b/src/prefect/server/utilities/database.py
@@ -5,6 +5,8 @@
allow the Prefect REST API to seamlessly switch between the two.
"""
+from __future__ import annotations
+
import datetime
import json
import operator
@@ -15,6 +17,7 @@
Any,
Callable,
Optional,
+ Type,
TypeVar,
Union,
overload,
@@ -101,8 +104,8 @@ class Timestamp(TypeDecorator[pendulum.DateTime]):
as naive timestamps without timezones) and recovered as UTC.
"""
- impl = sa.TIMESTAMP(timezone=True)
- cache_ok = True
+ impl: TypeEngine[Any] | type[TypeEngine[Any]] = sa.TIMESTAMP(timezone=True)
+ cache_ok: bool | None = True
def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
@@ -150,8 +153,8 @@ class UUID(TypeDecorator[uuid.UUID]):
hyphens.
"""
- impl = TypeEngine
- cache_ok = True
+ impl: type[TypeEngine[Any]] | TypeEngine[Any] = TypeEngine
+ cache_ok: bool | None = True
def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
@@ -191,8 +194,10 @@ class JSON(TypeDecorator[Any]):
to SQL compilation
"""
- impl = postgresql.JSONB
- cache_ok = True
+ impl: type[postgresql.JSONB] | type[TypeEngine[Any]] | TypeEngine[
+ Any
+ ] = postgresql.JSONB
+ cache_ok: bool | None = True
def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
@@ -230,7 +235,7 @@ class Pydantic(TypeDecorator[T]):
"""
impl = JSON
- cache_ok = True
+ cache_ok: bool | None = True
@overload
def __init__(
@@ -259,7 +264,9 @@ def __init__(
super().__init__()
self._pydantic_type = pydantic_type
if sa_column_type is not None:
- self.impl = sa_column_type
+ self.impl: type[JSON] | type[TypeEngine[Any]] | TypeEngine[
+ Any
+ ] = sa_column_type
def process_bind_param(
self, value: Optional[T], dialect: sa.Dialect
@@ -308,8 +315,8 @@ def bindparams_from_clause(
class date_add(functions.GenericFunction[pendulum.DateTime]):
"""Platform-independent way to add a timestamp and an interval"""
- type = Timestamp()
- inherit_cache = True
+ type: Timestamp = Timestamp()
+ inherit_cache: bool = True
def __init__(
self,
@@ -327,8 +334,8 @@ def __init__(
class interval_add(functions.GenericFunction[datetime.timedelta]):
"""Platform-independent way to add two intervals."""
- type = sa.Interval()
- inherit_cache = True
+ type: sa.Interval = sa.Interval()
+ inherit_cache: bool = True
def __init__(
self,
@@ -346,8 +353,8 @@ def __init__(
class date_diff(functions.GenericFunction[datetime.timedelta]):
"""Platform-independent difference of two timestamps. Computes d1 - d2."""
- type = sa.Interval()
- inherit_cache = True
+ type: sa.Interval = sa.Interval()
+ inherit_cache: bool = True
def __init__(
self,
@@ -363,8 +370,8 @@ def __init__(
class date_diff_seconds(functions.GenericFunction[float]):
"""Platform-independent calculation of the number of seconds between two timestamps or from 'now'"""
- type = sa.REAL
- inherit_cache = True
+ type: Type[sa.REAL[float]] = sa.REAL
+ inherit_cache: bool = True
def __init__(
self,
@@ -664,7 +671,7 @@ def sqlite_json_operators(
class greatest(functions.ReturnTypeFromArgs[T]):
- inherit_cache = True
+ inherit_cache: bool = True
@compiles(greatest, "sqlite")
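The `impl`/`cache_ok` annotations above exist because SQLAlchemy accepts either a `TypeEngine` instance or a `TypeEngine` subclass for `impl`, and subclasses may override with either form. A minimal illustrative decorator (not from Prefect) annotated the same way:

```python
from __future__ import annotations

from typing import Any, Optional

import sqlalchemy as sa
from sqlalchemy.types import TypeDecorator, TypeEngine


class LowercaseString(TypeDecorator[str]):
    """Illustrative only: stores strings lowercased."""

    impl: TypeEngine[Any] | type[TypeEngine[Any]] = sa.String(255)
    cache_ok: bool | None = True

    def process_bind_param(
        self, value: Optional[str], dialect: sa.Dialect
    ) -> Optional[str]:
        return value.lower() if value is not None else None
```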
diff --git a/src/prefect/server/utilities/messaging/__init__.py b/src/prefect/server/utilities/messaging/__init__.py
index 115f392dc512..9b2103c05660 100644
--- a/src/prefect/server/utilities/messaging/__init__.py
+++ b/src/prefect/server/utilities/messaging/__init__.py
@@ -2,13 +2,25 @@
from contextlib import asynccontextmanager, AbstractAsyncContextManager
from dataclasses import dataclass
import importlib
-from typing import Any, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Optional,
+ Protocol,
+ TypeVar,
+ Union,
+ runtime_checkable,
+)
from collections.abc import AsyncGenerator, Awaitable, Iterable, Mapping
from prefect.settings import PREFECT_MESSAGING_CACHE, PREFECT_MESSAGING_BROKER
from prefect.logging import get_logger
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
M = TypeVar("M", bound="Message", covariant=True)
@@ -77,13 +89,13 @@ def __init__(
deduplicate_by: Optional[str] = None,
) -> None:
self.topic = topic
- self.cache = cache or create_cache()
+ self.cache: Cache = cache or create_cache()
self.deduplicate_by = deduplicate_by
async def __aexit__(self, *args: Any) -> None:
pass
- async def publish_data(self, data: bytes, attributes: Mapping[str, str]):
+ async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None:
to_publish = [CapturedMessage(data, attributes)]
if self.deduplicate_by:
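This `TYPE_CHECKING` dance recurs throughout the diff: `logging` is only needed for the annotation, so the import is deferred to type-checking time and the annotation is written as a string. A stand-alone sketch of the pattern (the `get_logger` below is a stand-in, not Prefect's):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import logging  # seen by the type checker only; no runtime import cost


def get_logger(name: str) -> "logging.Logger":
    import logging as _logging  # real import happens lazily at call time

    return _logging.getLogger(name)


logger: "logging.Logger" = get_logger(__name__)
```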
diff --git a/src/prefect/server/utilities/messaging/memory.py b/src/prefect/server/utilities/messaging/memory.py
index 95934375a7d5..7d1cc4a8d6a2 100644
--- a/src/prefect/server/utilities/messaging/memory.py
+++ b/src/prefect/server/utilities/messaging/memory.py
@@ -5,7 +5,7 @@
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
-from typing import Any, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import uuid4
import anyio
@@ -19,7 +19,10 @@
from prefect.server.utilities.messaging import Publisher as _Publisher
from prefect.settings.context import get_current_settings
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
@dataclass
@@ -57,7 +60,7 @@ def __init__(
) -> None:
self.topic = topic
self.max_retries = max_retries
- self.dead_letter_queue_path = (
+ self.dead_letter_queue_path: Path = (
Path(dead_letter_queue_path)
if dead_letter_queue_path
else get_current_settings().home / "dlq"
@@ -155,7 +158,7 @@ def subscribe(self) -> Subscription:
def unsubscribe(self, subscription: Subscription) -> None:
self._subscriptions.remove(subscription)
- def clear(self):
+ def clear(self) -> None:
for subscription in self._subscriptions:
self.unsubscribe(subscription)
self._subscriptions = []
@@ -229,7 +232,7 @@ async def forget_duplicates(self, attribute: str, messages: Iterable[M]) -> None
class Publisher(_Publisher):
def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = None):
- self.topic = Topic.by_name(topic)
+ self.topic: Topic = Topic.by_name(topic)
self.deduplicate_by = deduplicate_by
self._cache = cache
@@ -254,7 +257,7 @@ async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None
class Consumer(_Consumer):
def __init__(self, topic: str, subscription: Optional[Subscription] = None):
- self.topic = Topic.by_name(topic)
+ self.topic: Topic = Topic.by_name(topic)
if not subscription:
subscription = self.topic.subscribe()
assert subscription.topic is self.topic
diff --git a/src/prefect/server/utilities/user_templates.py b/src/prefect/server/utilities/user_templates.py
index 8c721789f0cb..320459b03d62 100644
--- a/src/prefect/server/utilities/user_templates.py
+++ b/src/prefect/server/utilities/user_templates.py
@@ -1,6 +1,6 @@
"""Utilities to support safely rendering user-supplied templates"""
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
import jinja2.sandbox
from jinja2 import ChainableUndefined, nodes
@@ -8,7 +8,10 @@
from prefect.logging import get_logger
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
jinja2.sandbox.MAX_RANGE = 100
@@ -46,12 +49,12 @@ def __init__(self, message: Optional[str] = None, line_number: int = 0) -> None:
super().__init__(message)
-def register_user_template_filters(filters: dict[str, Any]):
+def register_user_template_filters(filters: dict[str, Any]) -> None:
"""Register additional filters that will be available to user templates"""
_template_environment.filters.update(filters)
-def validate_user_template(template: str):
+def validate_user_template(template: str) -> None:
root_node = _template_environment.parse(template)
_validate_loop_constraints(root_node)
diff --git a/src/prefect/states.py b/src/prefect/states.py
index d404dfb5e912..aec37d5215f5 100644
--- a/src/prefect/states.py
+++ b/src/prefect/states.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import datetime
import sys
@@ -33,12 +35,14 @@
from prefect.utilities.collections import ensure_iterable
if TYPE_CHECKING:
+ import logging
+
from prefect.results import (
R,
ResultStore,
)
-logger = get_logger("states")
+logger: "logging.Logger" = get_logger("states")
@deprecated.deprecated_parameter(
@@ -246,7 +250,7 @@ async def exception_to_failed_state(
result_store: Optional["ResultStore"] = None,
write_result: bool = False,
**kwargs: Any,
-) -> State:
+) -> State[BaseException]:
"""
Convenience function for creating `Failed` states from exceptions
"""
@@ -553,17 +557,17 @@ def is_state_iterable(obj: Any) -> TypeGuard[Iterable[State]]:
class StateGroup:
- def __init__(self, states: Iterable[State]) -> None:
- self.states = states
- self.type_counts = self._get_type_counts(states)
- self.total_count = len(states)
- self.cancelled_count = self.type_counts[StateType.CANCELLED]
- self.final_count = sum(state.is_final() for state in states)
- self.not_final_count = self.total_count - self.final_count
- self.paused_count = self.type_counts[StateType.PAUSED]
+ def __init__(self, states: list[State]) -> None:
+ self.states: list[State] = states
+ self.type_counts: dict[StateType, int] = self._get_type_counts(states)
+ self.total_count: int = len(states)
+ self.cancelled_count: int = self.type_counts[StateType.CANCELLED]
+ self.final_count: int = sum(state.is_final() for state in states)
+ self.not_final_count: int = self.total_count - self.final_count
+ self.paused_count: int = self.type_counts[StateType.PAUSED]
@property
- def fail_count(self):
+ def fail_count(self) -> int:
return self.type_counts[StateType.FAILED] + self.type_counts[StateType.CRASHED]
def all_completed(self) -> bool:
@@ -741,7 +745,7 @@ def Suspended(
pause_expiration_time: Optional[datetime.datetime] = None,
pause_key: Optional[str] = None,
**kwargs: Any,
-):
+) -> "State[R]":
"""Convenience function for creating `Suspended` states.
Returns:
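The `Iterable[State]` → `list[State]` narrowing in `StateGroup.__init__` above is not cosmetic: the constructor calls `len(states)` and iterates the argument several times, neither of which an arbitrary `Iterable` guarantees. A tiny sketch of the failure mode:

```python
def count(items: list[int]) -> int:
    # len() needs a Sized container; generators don't have one
    return len(items)


count([1, 2, 3])  # OK
# count(x for x in range(3))  # rejected by the type checker; at runtime:
#                             # TypeError: object of type 'generator' has no len()
```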
diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py
index 5285b9dec1f0..170d37b1987e 100644
--- a/src/prefect/task_engine.py
+++ b/src/prefect/task_engine.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import inspect
import logging
@@ -28,7 +30,7 @@
import anyio
import pendulum
from opentelemetry import trace
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, Self
from prefect.cache_policies import CachePolicy
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
@@ -123,7 +125,7 @@ class BaseTaskRunEngine(Generic[P, R]):
_last_event: Optional[PrefectEvent] = None
_telemetry: RunTelemetry = field(default_factory=RunTelemetry)
- def __post_init__(self):
+ def __post_init__(self) -> None:
if self.parameters is None:
self.parameters = {}
@@ -239,7 +241,7 @@ def is_running(self) -> bool:
return False
return task_run.state.is_running() or task_run.state.is_scheduled()
- def log_finished_message(self):
+ def log_finished_message(self) -> None:
if not self.task_run:
return
@@ -295,6 +297,7 @@ def handle_rollback(self, txn: Transaction) -> None:
@dataclass
class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
+ task_run: Optional[TaskRun] = None
_client: Optional[SyncPrefectClient] = None
@property
@@ -337,7 +340,7 @@ def can_retry(self, exc: Exception) -> bool:
)
return False
- def call_hooks(self, state: Optional[State] = None):
+ def call_hooks(self, state: Optional[State] = None) -> None:
if state is None:
state = self.state
task = self.task
@@ -372,7 +375,7 @@ def call_hooks(self, state: Optional[State] = None):
else:
self.logger.info(f"Hook {hook_name!r} finished running successfully")
- def begin_run(self):
+ def begin_run(self) -> None:
try:
self._resolve_parameters()
self._set_custom_task_run_name()
@@ -547,7 +550,7 @@ def handle_retry(self, exc: Exception) -> bool:
)
self.set_state(new_state, force=True)
- self.retries = self.retries + 1
+ self.retries: int = self.retries + 1
return True
elif self.retries >= self.task.retries:
self.logger.error(
@@ -641,7 +644,9 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
stack.enter_context(ConcurrencyContextV1())
stack.enter_context(ConcurrencyContext())
- self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
+ self.logger: "logging.Logger" = task_run_logger(
+ task_run=self.task_run, task=self.task
+ ) # type: ignore
yield
@@ -650,7 +655,7 @@ def initialize_run(
self,
task_run_id: Optional[UUID] = None,
dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
- ) -> Generator["SyncTaskRunEngine", Any, Any]:
+ ) -> Generator[Self, Any, Any]:
"""
Enters a client context and creates a task run if needed.
"""
@@ -720,7 +725,7 @@ def initialize_run(
self._is_started = False
self._client = None
- async def wait_until_ready(self):
+ async def wait_until_ready(self) -> None:
"""Waits until the scheduled time (if its the future), then enters Running."""
if scheduled_time := self.state.state_details.scheduled_time:
sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
@@ -827,6 +832,7 @@ def call_task_fn(
@dataclass
class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
+ task_run: TaskRun | None = None
_client: Optional[PrefectClient] = None
@property
@@ -868,7 +874,7 @@ async def can_retry(self, exc: Exception) -> bool:
)
return False
- async def call_hooks(self, state: Optional[State] = None):
+ async def call_hooks(self, state: Optional[State] = None) -> None:
if state is None:
state = self.state
task = self.task
@@ -903,7 +909,7 @@ async def call_hooks(self, state: Optional[State] = None):
else:
self.logger.info(f"Hook {hook_name!r} finished running successfully")
- async def begin_run(self):
+ async def begin_run(self) -> None:
try:
self._resolve_parameters()
self._set_custom_task_run_name()
@@ -1077,7 +1083,7 @@ async def handle_retry(self, exc: Exception) -> bool:
)
await self.set_state(new_state, force=True)
- self.retries = self.retries + 1
+ self.retries: int = self.retries + 1
return True
elif self.retries >= self.task.retries:
self.logger.error(
@@ -1171,7 +1177,9 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None):
)
stack.enter_context(ConcurrencyContext())
- self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
+ self.logger: "logging.Logger" = task_run_logger(
+ task_run=self.task_run, task=self.task
+ ) # type: ignore
yield
@@ -1180,7 +1188,7 @@ async def initialize_run(
self,
task_run_id: Optional[UUID] = None,
dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
- ) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
+ ) -> AsyncGenerator[Self, Any]:
"""
Enters a client context and creates a task run if needed.
"""
@@ -1248,7 +1256,7 @@ async def initialize_run(
self._is_started = False
self._client = None
- async def wait_until_ready(self):
+ async def wait_until_ready(self) -> None:
"""Waits until the scheduled time (if its the future), then enters Running."""
if scheduled_time := self.state.state_details.scheduled_time:
sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py
index 1efd284e34fd..9f42a8724cab 100644
--- a/src/prefect/task_runners.py
+++ b/src/prefect/task_runners.py
@@ -67,7 +67,7 @@ def __init__(self):
self._started = False
@property
- def name(self):
+ def name(self) -> str:
"""The name of this task runner"""
return type(self).__name__.lower().replace("taskrunner", "")
diff --git a/src/prefect/task_runs.py b/src/prefect/task_runs.py
index c76ea3f6418a..bdbce7e6138b 100644
--- a/src/prefect/task_runs.py
+++ b/src/prefect/task_runs.py
@@ -17,6 +17,9 @@
from prefect.events.filters import EventFilter, EventNameFilter
from prefect.logging.loggers import get_logger
+if TYPE_CHECKING:
+ import logging
+
class TaskRunWaiter:
"""
@@ -70,8 +73,8 @@ async def main():
_instance_lock = threading.Lock()
def __init__(self):
- self.logger = get_logger("TaskRunWaiter")
- self._consumer_task: asyncio.Task[None] | None = None
+ self.logger: "logging.Logger" = get_logger("TaskRunWaiter")
+ self._consumer_task: "asyncio.Task[None] | None" = None
self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
maxsize=10000, ttl=600
)
@@ -82,7 +85,7 @@ def __init__(self):
self._completion_events_lock = threading.Lock()
self._started = False
- def start(self):
+ def start(self) -> None:
"""
Start the TaskRunWaiter service.
"""
@@ -145,7 +148,7 @@ async def _consume_events(self, consumer_started: asyncio.Event):
except Exception as exc:
self.logger.error(f"Error processing event: {exc}")
- def stop(self):
+ def stop(self) -> None:
"""
Stop the TaskRunWaiter service.
"""
@@ -159,7 +162,7 @@ def stop(self):
@classmethod
async def wait_for_task_run(
cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
- ):
+ ) -> None:
"""
Wait for a task run to finish.
@@ -225,7 +228,7 @@ def add_done_callback(
instance._completion_callbacks[task_run_id] = callback
@classmethod
- def instance(cls):
+ def instance(cls) -> Self:
"""
Get the singleton instance of TaskRunWaiter.
"""
diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py
index f1aa9fc0d721..bcc37c3d73aa 100644
--- a/src/prefect/task_worker.py
+++ b/src/prefect/task_worker.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import inspect
import os
@@ -16,7 +18,7 @@
import uvicorn
from exceptiongroup import BaseExceptionGroup # novermin
from fastapi import FastAPI
-from typing_extensions import ParamSpec, TypeVar
+from typing_extensions import ParamSpec, Self, TypeVar
from websockets.exceptions import InvalidStatusCode
from prefect import Task
@@ -42,7 +44,10 @@
from prefect.utilities.services import start_client_metrics_server
from prefect.utilities.urls import url_for
-logger = get_logger("task_worker")
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger("task_worker")
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)
@@ -85,7 +90,7 @@ class TaskWorker:
def __init__(
self,
*tasks: Task[P, R],
- limit: Optional[int] = 10,
+ limit: int | None = 10,
):
self.tasks: list["Task[..., Any]"] = []
for t in tasks:
@@ -100,7 +105,7 @@ def __init__(
else:
self.tasks.append(t.with_options(persist_result=True))
- self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
+ self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
self._started_at: Optional[pendulum.DateTime] = None
self.stopping: bool = False
@@ -154,7 +159,7 @@ def current_tasks(self) -> Optional[int]:
def available_tasks(self) -> Optional[int]:
return int(self._limiter.available_tokens) if self._limiter else None
- def handle_sigterm(self, signum: int, frame: object):
+ def handle_sigterm(self, signum: int, frame: object) -> None:
"""
Shuts down the task worker when a SIGTERM is received.
"""
@@ -355,14 +360,14 @@ async def _submit_scheduled_task_run(self, task_run: TaskRun):
)
await asyncio.wrap_future(future)
- async def execute_task_run(self, task_run: TaskRun):
+ async def execute_task_run(self, task_run: TaskRun) -> None:
"""Execute a task run in the task worker."""
async with self if not self.started else asyncnullcontext():
token_acquired = await self._acquire_token(task_run.id)
if token_acquired:
await self._safe_submit_scheduled_task_run(task_run)
- async def __aenter__(self):
+ async def __aenter__(self) -> Self:
logger.debug("Starting task worker...")
if self._client._closed: # pyright: ignore[reportPrivateUsage]
diff --git a/src/prefect/telemetry/bootstrap.py b/src/prefect/telemetry/bootstrap.py
index 6646cf1d9d81..bbd3c95a78ad 100644
--- a/src/prefect/telemetry/bootstrap.py
+++ b/src/prefect/telemetry/bootstrap.py
@@ -4,7 +4,10 @@
from prefect.client.base import ServerType, determine_server_type
from prefect.logging.loggers import get_logger
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ import logging
+
+logger: "logging.Logger" = get_logger(__name__)
if TYPE_CHECKING:
from opentelemetry.sdk._logs import LoggerProvider
diff --git a/src/prefect/telemetry/run_telemetry.py b/src/prefect/telemetry/run_telemetry.py
index cb43719841cf..abad58f781fe 100644
--- a/src/prefect/telemetry/run_telemetry.py
+++ b/src/prefect/telemetry/run_telemetry.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import time
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
from opentelemetry import propagate, trace
from opentelemetry.context import Context
@@ -47,14 +49,14 @@ class RunTelemetry:
_tracer: "Tracer" = field(
default_factory=lambda: get_tracer("prefect", prefect.__version__)
)
- span: Optional[Span] = None
+ span: Span | None = None
async def async_start_span(
self,
run: FlowOrTaskRun,
client: PrefectClient,
- parameters: Optional[dict[str, Any]] = None,
- ):
+ parameters: dict[str, Any] | None = None,
+ ) -> Span:
traceparent, span = self._start_span(run, parameters)
if self._run_type(run) == "flow" and traceparent:
@@ -70,8 +72,8 @@ def start_span(
self,
run: FlowOrTaskRun,
client: SyncPrefectClient,
- parameters: Optional[dict[str, Any]] = None,
- ):
+ parameters: dict[str, Any] | None = None,
+ ) -> Span:
traceparent, span = self._start_span(run, parameters)
if self._run_type(run) == "flow" and traceparent:
@@ -84,8 +86,8 @@ def start_span(
def _start_span(
self,
run: FlowOrTaskRun,
- parameters: Optional[dict[str, Any]] = None,
- ) -> tuple[Optional[str], Span]:
+ parameters: dict[str, Any] | None = None,
+ ) -> tuple[str | None, Span]:
"""
Start a span for a run.
"""
@@ -139,8 +141,8 @@ def _run_type(self, run: FlowOrTaskRun) -> str:
return "task" if isinstance(run, TaskRun) else "flow"
def _trace_context_from_labels(
- self, labels: Optional[KeyValueLabels]
- ) -> Optional[Context]:
+ self, labels: KeyValueLabels | None
+ ) -> Context | None:
"""Get trace context from run labels if it exists."""
if not labels or LABELS_TRACEPARENT_KEY not in labels:
return None
@@ -148,7 +150,7 @@ def _trace_context_from_labels(
carrier = {TRACEPARENT_KEY: traceparent}
return propagate.extract(carrier)
- def _traceparent_from_span(self, span: Span) -> Optional[str]:
+ def _traceparent_from_span(self, span: Span) -> str | None:
carrier: dict[str, Any] = {}
propagate.inject(carrier, context=trace.set_span_in_context(span))
return carrier.get(TRACEPARENT_KEY)
@@ -162,7 +164,7 @@ def end_span_on_success(self) -> None:
self.span.end(time.time_ns())
self.span = None
- def end_span_on_failure(self, terminal_message: Optional[str] = None) -> None:
+ def end_span_on_failure(self, terminal_message: str | None = None) -> None:
"""
End a span for a run on failure.
"""
@@ -203,7 +205,7 @@ def update_run_name(self, name: str) -> None:
self.span.update_name(name=name)
self.span.set_attribute("prefect.run.name", name)
- def _parent_run(self) -> Union[FlowOrTaskRun, None]:
+ def _parent_run(self) -> FlowOrTaskRun | None:
"""
Identify the "parent run" for the current execution context.
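For context, this is the traceparent round-trip that `_traceparent_from_span` and `_trace_context_from_labels` build on, shown with the standard OpenTelemetry propagation API (a sketch, not Prefect code):

```python
from opentelemetry import propagate, trace

tracer = trace.get_tracer("example")

with tracer.start_as_current_span("work") as span:
    carrier: dict[str, str] = {}
    # serializes the active span into a W3C "traceparent" header value
    propagate.inject(carrier, context=trace.set_span_in_context(span))

# later (possibly in another process), restore the trace context:
ctx = propagate.extract(carrier)
```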
diff --git a/src/prefect/testing/cli.py b/src/prefect/testing/cli.py
index 7a660119257b..03c893527336 100644
--- a/src/prefect/testing/cli.py
+++ b/src/prefect/testing/cli.py
@@ -13,7 +13,7 @@
from prefect.utilities.asyncutils import in_async_main_thread
-def check_contains(cli_result: Result, content: str, should_contain: bool):
+def check_contains(cli_result: Result, content: str, should_contain: bool) -> None:
"""
Utility function to see if content is or is not in a CLI result.
diff --git a/src/prefect/testing/docker.py b/src/prefect/testing/docker.py
index d70a1a2e5542..1e0de02624ad 100644
--- a/src/prefect/testing/docker.py
+++ b/src/prefect/testing/docker.py
@@ -1,18 +1,18 @@
from contextlib import contextmanager
-from typing import Generator, List
+from typing import Any, Generator
from unittest import mock
from prefect.utilities.dockerutils import ImageBuilder
@contextmanager
-def capture_builders() -> Generator[List[ImageBuilder], None, None]:
+def capture_builders() -> Generator[list[ImageBuilder], None, None]:
"""Captures any instances of ImageBuilder created while this context is active"""
- builders = []
+ builders: list[ImageBuilder] = []
original_init = ImageBuilder.__init__
- def capture(self, *args, **kwargs):
+ def capture(self: ImageBuilder, *args: Any, **kwargs: Any):
builders.append(self)
original_init(self, *args, **kwargs)
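`capture_builders` above wraps `__init__` to record every instance created inside the context. The same pattern in a generic, self-contained form:

```python
from contextlib import contextmanager
from typing import Any, Generator
from unittest import mock


class Widget:
    def __init__(self, size: int) -> None:
        self.size = size


@contextmanager
def capture_widgets() -> Generator[list[Widget], None, None]:
    captured: list[Widget] = []
    original_init = Widget.__init__

    def capture(self: Widget, *args: Any, **kwargs: Any) -> None:
        captured.append(self)  # record the instance before delegating
        original_init(self, *args, **kwargs)

    with mock.patch.object(Widget, "__init__", capture):
        yield captured


with capture_widgets() as widgets:
    Widget(3)
assert widgets[0].size == 3
```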
diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py
index 07352f872afc..3b59ab64440e 100644
--- a/src/prefect/testing/fixtures.py
+++ b/src/prefect/testing/fixtures.py
@@ -4,7 +4,7 @@
import socket
import sys
from contextlib import contextmanager
-from typing import AsyncGenerator, Generator, List, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Generator, List, Optional, Union
from unittest import mock
from uuid import UUID
@@ -39,7 +39,9 @@
@pytest.fixture(autouse=True)
-def add_prefect_loggers_to_caplog(caplog):
+def add_prefect_loggers_to_caplog(
+ caplog: pytest.LogCaptureFixture,
+) -> Generator[None, None, None]:
import logging
logger = logging.getLogger("prefect")
@@ -57,7 +59,9 @@ def is_port_in_use(port: int) -> bool:
@pytest.fixture(scope="session")
-async def hosted_api_server(unused_tcp_port_factory):
+async def hosted_api_server(
+ unused_tcp_port_factory: Callable[[], int],
+) -> AsyncGenerator[str, None]:
"""
Runs an instance of the Prefect API server in a subprocess instead of using the
ephemeral application.
@@ -134,7 +138,7 @@ async def hosted_api_server(unused_tcp_port_factory):
@pytest.fixture(autouse=True)
-def use_hosted_api_server(hosted_api_server):
+def use_hosted_api_server(hosted_api_server: str) -> Generator[str, None, None]:
"""
Sets `PREFECT_API_URL` to the test session's hosted API endpoint.
"""
@@ -148,7 +152,7 @@ def use_hosted_api_server(hosted_api_server):
@pytest.fixture
-def disable_hosted_api_server():
+def disable_hosted_api_server() -> Generator[None, None, None]:
"""
Disables the hosted API server by setting `PREFECT_API_URL` to `None`.
"""
@@ -157,11 +161,13 @@ def disable_hosted_api_server():
PREFECT_API_URL: None,
}
):
- yield hosted_api_server
+ yield
@pytest.fixture
-def enable_ephemeral_server(disable_hosted_api_server):
+def enable_ephemeral_server(
+ disable_hosted_api_server: None,
+) -> Generator[None, None, None]:
"""
Enables the ephemeral server by setting `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` to `True`.
"""
@@ -170,13 +176,15 @@ def enable_ephemeral_server(disable_hosted_api_server):
PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True,
}
):
- yield hosted_api_server
+ yield
SubprocessASGIServer().stop()
@pytest.fixture
-def mock_anyio_sleep(monkeypatch):
+def mock_anyio_sleep(
+ monkeypatch: pytest.MonkeyPatch,
+) -> Generator[Callable[[float], None], None, None]:
"""
Mock sleep used to not actually sleep but to set the current time to now + sleep
delay seconds while still yielding to other tasks in the event loop.
@@ -188,18 +196,18 @@ def mock_anyio_sleep(monkeypatch):
original_sleep = anyio.sleep
time_shift = 0.0
- async def callback(delay_in_seconds):
+ async def callback(delay_in_seconds: float) -> None:
nonlocal time_shift
time_shift += float(delay_in_seconds)
# Preserve yield effects of sleep
await original_sleep(0)
- def latest_now(*args):
+ def latest_now(*args: Any) -> pendulum.DateTime:
# Fast-forwards the time by the total sleep time
return original_now(*args).add(
# Ensure we retain float precision
seconds=int(time_shift),
- microseconds=(time_shift - int(time_shift)) * 1000000,
+ microseconds=int((time_shift - int(time_shift)) * 1000000),
)
monkeypatch.setattr("pendulum.now", latest_now)
@@ -368,7 +376,7 @@ def events_cloud_api_url(events_server: WebSocketServer, unused_tcp_port: int) -
@pytest.fixture
-def mock_should_emit_events(monkeypatch) -> mock.Mock:
+def mock_should_emit_events(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
m = mock.Mock()
m.return_value = True
monkeypatch.setattr("prefect.events.utilities.should_emit_events", m)
@@ -376,7 +384,9 @@ def mock_should_emit_events(monkeypatch) -> mock.Mock:
@pytest.fixture
-def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]:
+def asserting_events_worker(
+ monkeypatch: pytest.MonkeyPatch,
+) -> Generator[EventsWorker, None, None]:
worker = EventsWorker.instance(AssertingEventsClient)
# Always yield the asserting worker when new instances are retrieved
monkeypatch.setattr(EventsWorker, "instance", lambda *_: worker)
@@ -388,7 +398,7 @@ def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]:
@pytest.fixture
def asserting_and_emitting_events_worker(
- monkeypatch,
+ monkeypatch: pytest.MonkeyPatch,
) -> Generator[EventsWorker, None, None]:
worker = EventsWorker.instance(AssertingPassthroughEventsClient)
# Always yield the asserting worker when new instances are retrieved
@@ -400,7 +410,9 @@ def asserting_and_emitting_events_worker(
@pytest.fixture
-async def events_pipeline(asserting_events_worker: EventsWorker):
+async def events_pipeline(
+ asserting_events_worker: EventsWorker,
+) -> AsyncGenerator[EventsPipeline, None]:
class AssertingEventsPipeline(EventsPipeline):
@sync_compatible
async def process_events(
@@ -435,7 +447,9 @@ async def wait_for_min_events():
@pytest.fixture
-async def emitting_events_pipeline(asserting_and_emitting_events_worker: EventsWorker):
+async def emitting_events_pipeline(
+ asserting_and_emitting_events_worker: EventsWorker,
+) -> AsyncGenerator[EventsPipeline, None]:
class AssertingAndEmittingEventsPipeline(EventsPipeline):
@sync_compatible
async def process_events(self):
@@ -449,14 +463,16 @@ async def process_events(self):
@pytest.fixture
-def reset_worker_events(asserting_events_worker: EventsWorker):
+def reset_worker_events(
+ asserting_events_worker: EventsWorker,
+) -> Generator[None, None, None]:
yield
assert isinstance(asserting_events_worker._client, AssertingEventsClient)
asserting_events_worker._client.events = []
@pytest.fixture
-def enable_lineage_events():
+def enable_lineage_events() -> Generator[None, None, None]:
"""A fixture that ensures lineage events are enabled."""
with temporary_settings(updates={PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED: True}):
yield
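The fixture annotations above all follow the same recipe: a fixture that `yield`s a value of type `T` is annotated `Generator[T, None, None]` (or `AsyncGenerator[T, None]` for async fixtures), and injected pytest helpers get their real types such as `pytest.MonkeyPatch` and `pytest.LogCaptureFixture`. A minimal sketch with illustrative names:

```python
from typing import Generator

import pytest


@pytest.fixture
def server_url(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None]:
    monkeypatch.setenv("API_URL", "http://localhost:4200")
    yield "http://localhost:4200"  # the value injected into tests
    # teardown runs here; monkeypatch undoes the env change automatically
```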
diff --git a/src/prefect/testing/standard_test_suites/blocks.py b/src/prefect/testing/standard_test_suites/blocks.py
index 0fa04776784e..f71f80c1244e 100644
--- a/src/prefect/testing/standard_test_suites/blocks.py
+++ b/src/prefect/testing/standard_test_suites/blocks.py
@@ -14,13 +14,13 @@ class BlockStandardTestSuite(ABC):
def block(self) -> type[Block]:
pass
- def test_has_a_description(self, block: type[Block]):
+ def test_has_a_description(self, block: type[Block]) -> None:
assert block.get_description()
- def test_has_a_documentation_url(self, block: type[Block]):
+ def test_has_a_documentation_url(self, block: type[Block]) -> None:
assert block._documentation_url
- def test_all_fields_have_a_description(self, block: type[Block]):
+ def test_all_fields_have_a_description(self, block: type[Block]) -> None:
for name, field in block.model_fields.items():
if Block.annotation_refers_to_block_class(field.annotation):
# TODO: Block field descriptions aren't currently handled by the UI, so
@@ -34,7 +34,7 @@ def test_all_fields_have_a_description(self, block: type[Block]):
"."
), f"{name} description on {block.__name__} does not end with a period"
- def test_has_a_valid_code_example(self, block: type[Block]):
+ def test_has_a_valid_code_example(self, block: type[Block]) -> None:
code_example = block.get_code_example()
assert code_example is not None, f"{block.__name__} is missing a code example"
@@ -55,7 +55,7 @@ def test_has_a_valid_code_example(self, block: type[Block]):
f" matching the pattern {block_load_pattern}"
)
- def test_has_a_valid_image(self, block: type[Block]):
+ def test_has_a_valid_image(self, block: type[Block]) -> None:
logo_url = block._logo_url
assert (
logo_url is not None
diff --git a/src/prefect/testing/utilities.py b/src/prefect/testing/utilities.py
index 2b874240ec39..d22fb2887424 100644
--- a/src/prefect/testing/utilities.py
+++ b/src/prefect/testing/utilities.py
@@ -2,6 +2,8 @@
Internal utilities for tests.
"""
+from __future__ import annotations
+
import atexit
import shutil
import warnings
@@ -9,7 +11,7 @@
from pathlib import Path
from pprint import pprint
from tempfile import mkdtemp
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Generator
import prefect.context
import prefect.settings
@@ -30,10 +32,11 @@
if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
+ from prefect.client.schemas.objects import FlowRun
from prefect.filesystems import ReadableFileSystem
-def exceptions_equal(a, b):
+def exceptions_equal(a: Exception, b: Exception) -> bool:
"""
Exceptions cannot be compared by `==`. They can be compared using `is` but this
will fail if the exception is serialized/deserialized so this utility does its
@@ -54,9 +57,9 @@ def exceptions_equal(a, b):
def kubernetes_environments_equal(
- actual: List[Dict[str, str]],
- expected: Union[List[Dict[str, str]], Dict[str, str]],
-):
+ actual: list[dict[str, str]],
+ expected: list[dict[str, str]] | dict[str, str],
+) -> bool:
# Convert to a required format and sort by name
if isinstance(expected, dict):
expected = [{"name": key, "value": value} for key, value in expected.items()]
@@ -90,7 +93,9 @@ def kubernetes_environments_equal(
@contextmanager
-def assert_does_not_warn(ignore_warnings=[]):
+def assert_does_not_warn(
+ ignore_warnings: list[type[Warning]] | None = None,
+) -> Generator[None, None, None]:
"""
Converts warnings to errors within this context to assert warnings are not raised,
except for those specified in ignore_warnings.
@@ -98,6 +103,7 @@ def assert_does_not_warn(ignore_warnings=[]):
Parameters:
- ignore_warnings: List of warning types to ignore. Example: [DeprecationWarning, UserWarning]
"""
+ ignore_warnings = ignore_warnings or []
with warnings.catch_warnings():
warnings.simplefilter("error")
for warning_type in ignore_warnings:
@@ -110,7 +116,7 @@ def assert_does_not_warn(ignore_warnings=[]):
@contextmanager
-def prefect_test_harness(server_startup_timeout: Optional[int] = 30):
+def prefect_test_harness(server_startup_timeout: int | None = 30):
"""
Temporarily run flows against a local SQLite database for testing.
@@ -175,7 +181,7 @@ def cleanup_temp_dir(temp_dir):
test_server.stop()
-async def get_most_recent_flow_run(client: "PrefectClient" = None):
+async def get_most_recent_flow_run(client: "PrefectClient | None" = None) -> "FlowRun":
if client is None:
client = get_client()
@@ -187,8 +193,8 @@ async def get_most_recent_flow_run(client: "PrefectClient" = None):
def assert_blocks_equal(
- found: Block, expected: Block, exclude_private: bool = True, **kwargs
-) -> bool:
+ found: Block, expected: Block, exclude_private: bool = True, **kwargs: Any
+) -> None:
assert isinstance(
found, type(expected)
), f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}"
@@ -204,8 +210,8 @@ def assert_blocks_equal(
async def assert_uses_result_serializer(
- state: State, serializer: Union[str, Serializer], client: "PrefectClient"
-):
+ state: State, serializer: str | Serializer, client: "PrefectClient"
+) -> None:
assert isinstance(state.data, (ResultRecord, ResultRecordMetadata))
if isinstance(state.data, ResultRecord):
result_serializer = state.data.metadata.serializer
@@ -240,8 +246,8 @@ async def assert_uses_result_serializer(
@inject_client
async def assert_uses_result_storage(
- state: State, storage: Union[str, "ReadableFileSystem"], client: "PrefectClient"
-):
+ state: State, storage: "str | ReadableFileSystem", client: "PrefectClient"
+) -> None:
assert isinstance(state.data, (ResultRecord, ResultRecordMetadata))
if isinstance(state.data, ResultRecord):
assert_blocks_equal(
@@ -267,11 +273,11 @@ async def assert_uses_result_storage(
)
-def a_test_step(**kwargs):
+def a_test_step(**kwargs: Any) -> dict[str, Any]:
kwargs.update({"output1": 1, "output2": ["b", 2, 3]})
return kwargs
-def b_test_step(**kwargs):
+def b_test_step(**kwargs: Any) -> dict[str, Any]:
kwargs.update({"output1": 1, "output2": ["b", 2, 3]})
return kwargs
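The `assert_does_not_warn` change above also fixes a classic pitfall: a mutable default (`ignore_warnings=[]`) is created once at definition time and shared across calls. In isolation:

```python
def buggy(items: list[int] = []) -> list[int]:
    items.append(1)
    return items


buggy()  # [1]
buggy()  # [1, 1] -- the same list object, mutated twice


def fixed(items: list[int] | None = None) -> list[int]:
    items = items or []  # fresh list per call
    items.append(1)
    return items
```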
diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py
index 6e882dbf5c6f..e8cdd65034f4 100644
--- a/src/prefect/transactions.py
+++ b/src/prefect/transactions.py
@@ -173,7 +173,7 @@ def is_pending(self) -> bool:
def is_active(self) -> bool:
return self.state == TransactionState.ACTIVE
- def __enter__(self):
+ def __enter__(self) -> Self:
if self._token is not None:
raise RuntimeError(
"Context already entered. Context enter calls cannot be nested."
@@ -206,7 +206,7 @@ def __enter__(self):
self._token = self.__var__.set(self)
return self
- def __exit__(self, *exc_info: Any):
+ def __exit__(self, *exc_info: Any) -> None:
exc_type, exc_val, _ = exc_info
if not self._token:
raise RuntimeError(
@@ -235,7 +235,7 @@ def __exit__(self, *exc_info: Any):
self.reset()
- def begin(self):
+ def begin(self) -> None:
if (
self.store
and self.key
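Annotating `__enter__` with `Self` (rather than the class name) preserves subclass types through `with` blocks. A compact sketch of the protocol as typed above:

```python
from types import TracebackType
from typing import Optional

from typing_extensions import Self


class Session:
    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        pass


class AuditedSession(Session):
    pass


with AuditedSession() as s:
    pass  # s is inferred as AuditedSession, not Session
```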
diff --git a/src/prefect/types/__init__.py b/src/prefect/types/__init__.py
index f36622f5a3df..c70e5cc3c8b1 100644
--- a/src/prefect/types/__init__.py
+++ b/src/prefect/types/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from functools import partial
from typing import Annotated, Any, Dict, List, Optional, Set, TypeVar, Union
from typing_extensions import Literal, TypeAlias
@@ -114,7 +116,7 @@ class SecretDict(pydantic.Secret[Dict[str, Any]]):
def validate_set_T_from_delim_string(
- value: Union[str, T, Set[T], None], type_, delim=None
+ value: Union[str, T, Set[T], None], type_: type[T], delim: str | None = None
) -> Set[T]:
"""
"no-info" before validator useful in scooping env vars
diff --git a/src/prefect/utilities/_deprecated.py b/src/prefect/utilities/_deprecated.py
new file mode 100644
index 000000000000..2d5fd8a87c7c
--- /dev/null
+++ b/src/prefect/utilities/_deprecated.py
@@ -0,0 +1,38 @@
+from typing import Any
+
+from jsonpatch import ( # type: ignore # no typing stubs available, see https://github.com/stefankoegl/python-json-patch/issues/158
+ JsonPatch as JsonPatchBase,
+)
+from pydantic import GetJsonSchemaHandler
+from pydantic.json_schema import JsonSchemaValue
+from pydantic_core import core_schema
+
+
+class JsonPatch(JsonPatchBase):
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: Any, handler: GetJsonSchemaHandler
+ ) -> core_schema.CoreSchema:
+ return core_schema.typed_dict_schema(
+ {"patch": core_schema.typed_dict_field(core_schema.dict_schema())}
+ )
+
+ @classmethod
+ def __get_pydantic_json_schema__(
+ cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+ ) -> JsonSchemaValue:
+ json_schema = handler(core_schema)
+ json_schema = handler.resolve_ref_schema(json_schema)
+ json_schema.pop("required", None)
+ json_schema.pop("properties", None)
+ json_schema.update(
+ {
+ "type": "array",
+ "format": "rfc6902",
+ "items": {
+ "type": "object",
+ "additionalProperties": {"type": "string"},
+ },
+ }
+ )
+ return json_schema
diff --git a/src/prefect/utilities/filesystem.py b/src/prefect/utilities/filesystem.py
index 1301a1178c91..92579c444c87 100644
--- a/src/prefect/utilities/filesystem.py
+++ b/src/prefect/utilities/filesystem.py
@@ -8,7 +8,7 @@
from collections.abc import Iterable
from contextlib import contextmanager
from pathlib import Path, PureWindowsPath
-from typing import AnyStr, Optional, Union, cast
+from typing import Any, AnyStr, Optional, Union, cast
# fsspec has no stubs, see https://github.com/fsspec/filesystem_spec/issues/625
import fsspec # type: ignore
@@ -114,7 +114,7 @@ def filename(path: str) -> str:
return path.split(sep)[-1]
-def is_local_path(path: Union[str, pathlib.Path, OpenFile]) -> bool:
+def is_local_path(path: Union[str, pathlib.Path, Any]) -> bool:
"""Check if the given path points to a local or remote file system"""
if isinstance(path, str):
try:
diff --git a/src/prefect/utilities/generics.py b/src/prefect/utilities/generics.py
index a3ff954c247c..9d5097b4e6aa 100644
--- a/src/prefect/utilities/generics.py
+++ b/src/prefect/utilities/generics.py
@@ -5,7 +5,7 @@
T = TypeVar("T", bound=BaseModel)
-ListValidator = SchemaValidator(
+ListValidator: SchemaValidator = SchemaValidator(
schema=core_schema.list_schema(
items_schema=core_schema.dict_schema(
keys_schema=core_schema.str_schema(), values_schema=core_schema.any_schema()
diff --git a/src/prefect/utilities/pydantic.py b/src/prefect/utilities/pydantic.py
index 6086931fbbbd..381456ca23bc 100644
--- a/src/prefect/utilities/pydantic.py
+++ b/src/prefect/utilities/pydantic.py
@@ -1,3 +1,4 @@
+import warnings
from typing import (
Any,
Callable,
@@ -10,18 +11,13 @@
overload,
)
-from jsonpatch import ( # type: ignore # no typing stubs available, see https://github.com/stefankoegl/python-json-patch/issues/158
- JsonPatch as JsonPatchBase,
-)
from pydantic import (
BaseModel,
- GetJsonSchemaHandler,
Secret,
TypeAdapter,
ValidationError,
)
-from pydantic.json_schema import JsonSchemaValue
-from pydantic_core import core_schema, to_jsonable_python
+from pydantic_core import to_jsonable_python
from typing_extensions import Literal
from prefect.utilities.dispatch import get_dispatch_key, lookup_type, register_base_type
@@ -262,36 +258,6 @@ def __repr__(self) -> str:
return f"PartialModel(cls={self.model_cls.__name__}, {dsp_fields})"
-class JsonPatch(JsonPatchBase):
- @classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: Any, handler: GetJsonSchemaHandler
- ) -> core_schema.CoreSchema:
- return core_schema.typed_dict_schema(
- {"patch": core_schema.typed_dict_field(core_schema.dict_schema())}
- )
-
- @classmethod
- def __get_pydantic_json_schema__(
- cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
- ) -> JsonSchemaValue:
- json_schema = handler(core_schema)
- json_schema = handler.resolve_ref_schema(json_schema)
- json_schema.pop("required", None)
- json_schema.pop("properties", None)
- json_schema.update(
- {
- "type": "array",
- "format": "rfc6902",
- "items": {
- "type": "object",
- "additionalProperties": {"type": "string"},
- },
- }
- )
- return json_schema
-
-
def custom_pydantic_encoder(
type_encoders: dict[Any, Callable[[type[Any]], Any]], obj: Any
) -> Any:
@@ -382,3 +348,22 @@ def handle_secret_render(value: object, context: dict[str, Any]) -> object:
elif isinstance(value, BaseModel):
return value.model_dump(context=context)
return value
+
+
+def __getattr__(name: str) -> Any:
+ """
+ Handles imports from this module that are deprecated.
+ """
+
+ if name == "JsonPatch":
+ warnings.warn(
+ "JsonPatch is deprecated and will be removed after March 2025. "
+ "Please use `JsonPatch` from the `jsonpatch` package instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ from ._deprecated import JsonPatch
+
+ return JsonPatch
+ else:
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
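The module-level `__getattr__` added above is the PEP 562 hook: it fires only for names that normal attribute lookup misses, which makes it a natural place for import-time deprecation warnings. The shape of the pattern, with hypothetical names:

```python
import warnings
from typing import Any


def __getattr__(name: str) -> Any:  # PEP 562: called for missing module attrs
    if name == "OldThing":  # hypothetical deprecated name
        warnings.warn(
            "OldThing is deprecated; import NewThing instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        from newhome import NewThing  # hypothetical new location

        return NewThing
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```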
diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py
index ebc67c05d794..16b0d4534945 100644
--- a/src/prefect/workers/base.py
+++ b/src/prefect/workers/base.py
@@ -1,9 +1,22 @@
+from __future__ import annotations
+
import abc
import asyncio
import threading
from contextlib import AsyncExitStack
from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Set,
+ Type,
+ Union,
+)
from uuid import UUID, uuid4
import anyio
@@ -13,7 +26,7 @@
from importlib_metadata import distributions
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.json_schema import GenerateJsonSchema
-from typing_extensions import Literal
+from typing_extensions import Literal, Self, TypeVar
import prefect
from prefect._internal.schemas.validators import return_v_or_none
@@ -104,7 +117,7 @@ class BaseJobConfiguration(BaseModel):
_related_objects: Dict[str, Any] = PrivateAttr(default_factory=dict)
@property
- def is_using_a_runner(self):
+ def is_using_a_runner(self) -> bool:
return self.command is not None and "prefect flow-run execute" in self.command
@field_validator("command")
@@ -175,7 +188,7 @@ async def from_template_and_values(
return cls(**populated_configuration)
@classmethod
- def json_template(cls) -> dict:
+ def json_template(cls) -> dict[str, Any]:
"""Returns a dict with job configuration as keys and the corresponding templates as values
Defaults to using the job configuration parameter name as the template variable name.
@@ -186,7 +199,7 @@ def json_template(cls) -> dict:
key2: '{{ template2 }}', # `template2` specifically provide as template
}
"""
- configuration = {}
+ configuration: dict[str, Any] = {}
properties = cls.model_json_schema()["properties"]
for k, v in properties.items():
if v.get("template"):
@@ -202,7 +215,7 @@ def prepare_for_flow_run(
flow_run: "FlowRun",
deployment: Optional["DeploymentResponse"] = None,
flow: Optional["Flow"] = None,
- ):
+ ) -> None:
"""
Prepare the job configuration for a flow run.
@@ -368,15 +381,20 @@ class BaseWorkerResult(BaseModel, abc.ABC):
identifier: str
status_code: int
- def __bool__(self):
+ def __bool__(self) -> bool:
return self.status_code == 0
+C = TypeVar("C", bound=BaseJobConfiguration)
+V = TypeVar("V", bound=BaseVariables)
+R = TypeVar("R", bound=BaseWorkerResult)
+
+
@register_base_type
-class BaseWorker(abc.ABC):
+class BaseWorker(abc.ABC, Generic[C, V, R]):
type: str
- job_configuration: Type[BaseJobConfiguration] = BaseJobConfiguration
- job_configuration_variables: Optional[Type[BaseVariables]] = None
+ job_configuration: Type[C] = BaseJobConfiguration # type: ignore
+ job_configuration_variables: Optional[Type[V]] = None
_documentation_url = ""
_logo_url = ""
@@ -418,7 +436,7 @@ def __init__(
"""
if name and ("/" in name or "%" in name):
raise ValueError("Worker name cannot contain '/' or '%'")
- self.name = name or f"{self.__class__.__name__} {uuid4()}"
+ self.name: str = name or f"{self.__class__.__name__} {uuid4()}"
self._started_event: Optional[Event] = None
self.backend_id: Optional[UUID] = None
self._logger = get_worker_logger(self)
@@ -432,7 +450,7 @@ def __init__(
self._prefetch_seconds: float = (
prefetch_seconds or PREFECT_WORKER_PREFETCH_SECONDS.value()
)
- self.heartbeat_interval_seconds = (
+ self.heartbeat_interval_seconds: int = (
heartbeat_interval_seconds or PREFECT_WORKER_HEARTBEAT_SECONDS.value()
)
@@ -461,7 +479,7 @@ def get_description(cls) -> str:
return cls._description
@classmethod
- def get_default_base_job_template(cls) -> Dict:
+ def get_default_base_job_template(cls) -> dict[str, Any]:
if cls.job_configuration_variables is None:
schema = cls.job_configuration.model_json_schema()
# remove "template" key from all dicts in schema['properties'] because it is not a
@@ -479,7 +497,9 @@ def get_default_base_job_template(cls) -> Dict:
}
@staticmethod
- def get_worker_class_from_type(type: str) -> Optional[Type["BaseWorker"]]:
+ def get_worker_class_from_type(
+ type: str,
+ ) -> Optional[Type["BaseWorker[Any, Any, Any]"]]:
"""
Returns the worker class for a given worker type. If the worker type
is not recognized, returns None.
@@ -500,7 +520,7 @@ def get_all_available_worker_types() -> List[str]:
return list(worker_registry.keys())
return []
- def get_name_slug(self):
+ def get_name_slug(self) -> str:
return slugify(self.name)
def get_flow_run_logger(self, flow_run: "FlowRun") -> PrefectLogAdapter:
@@ -524,7 +544,7 @@ async def start(
run_once: bool = False,
with_healthcheck: bool = False,
printer: Callable[..., None] = print,
- ):
+ ) -> None:
"""
Starts the worker and runs the main worker loops.
@@ -603,9 +623,9 @@ async def start(
async def run(
self,
flow_run: "FlowRun",
- configuration: BaseJobConfiguration,
- task_status: Optional[anyio.abc.TaskStatus] = None,
- ) -> BaseWorkerResult:
+ configuration: C,
+ task_status: Optional[anyio.abc.TaskStatus[int]] = None,
+ ) -> R:
"""
Runs a given flow run on the current worker.
"""
@@ -614,12 +634,12 @@ async def run(
)
@classmethod
- def __dispatch_key__(cls):
+ def __dispatch_key__(cls) -> str | None:
if cls.__name__ == "BaseWorker":
return None # The base class is abstract
return cls.type
- async def setup(self):
+ async def setup(self) -> None:
"""Prepares the worker to run."""
self._logger.debug("Setting up worker...")
self._runs_task_group = anyio.create_task_group()
@@ -637,10 +657,10 @@ async def setup(self):
self.is_setup = True
- async def teardown(self, *exc_info):
+ async def teardown(self, *exc_info: Any) -> None:
"""Cleans up resources after the worker is stopped."""
self._logger.debug("Tearing down worker...")
- self.is_setup = False
+ self.is_setup: bool = False
for scope in self._scheduled_task_scopes:
scope.cancel()
@@ -684,14 +704,16 @@ def is_worker_still_polling(self, query_interval_seconds: float) -> bool:
return is_still_polling
- async def get_and_submit_flow_runs(self):
+ async def get_and_submit_flow_runs(self) -> list["FlowRun"]:
runs_response = await self._get_scheduled_flow_runs()
self._last_polled_time = pendulum.now("utc")
return await self._submit_scheduled_flow_runs(flow_run_response=runs_response)
- async def _update_local_work_pool_info(self):
+ async def _update_local_work_pool_info(self) -> None:
+ if TYPE_CHECKING:
+ assert self._client is not None
try:
work_pool = await self._client.read_work_pool(
work_pool_name=self._work_pool_name
@@ -803,7 +825,7 @@ async def _send_worker_heartbeat(self) -> Optional[UUID]:
return worker_id
- async def sync_with_backend(self):
+ async def sync_with_backend(self) -> None:
"""
Updates the worker's local information about its current work pool and

queues. Sends a worker heartbeat to the API.
@@ -1042,7 +1064,7 @@ def _release_limit_slot(self, flow_run_id: str) -> None:
self._limiter.release_on_behalf_of(flow_run_id)
self._logger.debug("Limit slot released for flow run '%s'", flow_run_id)
- def get_status(self):
+ def get_status(self) -> dict[str, Any]:
"""
Retrieves the status of the current worker including its name, current worker
pool, the work pool queues it is polling, and its local settings.
@@ -1234,17 +1256,17 @@ async def _give_worker_labels_to_flow_run(self, flow_run_id: UUID):
await self._client.update_flow_run_labels(flow_run_id, labels)
- async def __aenter__(self):
+ async def __aenter__(self) -> Self:
self._logger.debug("Entering worker context...")
await self.setup()
return self
- async def __aexit__(self, *exc_info):
+ async def __aexit__(self, *exc_info: Any) -> None:
self._logger.debug("Exiting worker context...")
await self.teardown(*exc_info)
- def __repr__(self):
+ def __repr__(self) -> str:
return f"Worker(pool={self._work_pool_name!r}, name={self.name!r})"
def _event_resource(self):
diff --git a/src/prefect/workers/process.py b/src/prefect/workers/process.py
index 4a89665f4fca..e0059d185395 100644
--- a/src/prefect/workers/process.py
+++ b/src/prefect/workers/process.py
@@ -13,6 +13,7 @@
For more information about work pools and workers,
check out the [Prefect docs](/concepts/work-pools/).
"""
+from __future__ import annotations
import contextlib
import os
@@ -30,7 +31,7 @@
import anyio.abc
from pydantic import Field, field_validator
-from prefect._internal.schemas.validators import validate_command
+from prefect._internal.schemas.validators import validate_working_dir
from prefect.client.schemas import FlowRun
from prefect.client.schemas.filters import (
FlowRunFilter,
@@ -85,19 +86,21 @@ class ProcessJobConfiguration(BaseJobConfiguration):
@field_validator("working_dir")
@classmethod
- def validate_command(cls, v: str) -> str:
- return validate_command(v)
+ def validate_working_dir(cls, v: Path | str | None) -> Path | None:
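+ # only raw strings need parsing; Path and None values pass through unchanged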
+ if isinstance(v, str):
+ return validate_working_dir(v)
+ return v
def prepare_for_flow_run(
self,
flow_run: "FlowRun",
deployment: Optional["DeploymentResponse"] = None,
flow: Optional["Flow"] = None,
- ):
+ ) -> None:
super().prepare_for_flow_run(flow_run, deployment, flow)
- self.env = {**os.environ, **self.env}
- self.command = (
+ self.env: dict[str, str | None] = {**os.environ, **self.env}
+ self.command: str | None = (
f"{get_sys_executable()} -m prefect.engine"
if self.command == self._base_flow_run_command()
else self.command
@@ -134,10 +137,12 @@ class ProcessWorkerResult(BaseWorkerResult):
"""Contains information about the final state of a completed process"""
-class ProcessWorker(BaseWorker):
+class ProcessWorker(
+ BaseWorker[ProcessJobConfiguration, ProcessVariables, ProcessWorkerResult]
+):
type = "process"
- job_configuration = ProcessJobConfiguration
- job_configuration_variables = ProcessVariables
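+ # annotate the overrides explicitly so pyright can check them against BaseWorker's type parameters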
+ job_configuration: type[ProcessJobConfiguration] = ProcessJobConfiguration
+ job_configuration_variables: type[ProcessVariables] | None = ProcessVariables
_description = (
"Execute flow runs as subprocesses on a worker. Works well for local execution"
@@ -152,7 +157,7 @@ async def start(
run_once: bool = False,
with_healthcheck: bool = False,
printer: Callable[..., None] = print,
- ):
+ ) -> None:
"""
Starts the worker and runs the main worker loops.
@@ -241,8 +246,8 @@ async def run(
self,
flow_run: FlowRun,
configuration: ProcessJobConfiguration,
- task_status: Optional[anyio.abc.TaskStatus] = None,
- ):
+ task_status: Optional[anyio.abc.TaskStatus[int]] = None,
+ ) -> ProcessWorkerResult:
command = configuration.command
if not command:
command = f"{get_sys_executable()} -m prefect.engine"
@@ -322,7 +327,7 @@ async def kill_process(
self,
infrastructure_pid: str,
grace_seconds: int = 30,
- ):
+ ) -> None:
hostname, pid = _parse_infrastructure_pid(infrastructure_pid)
if hostname != socket.gethostname():
@@ -372,7 +377,7 @@ async def kill_process(
# process ended right after the check above.
return
- async def check_for_cancelled_flow_runs(self):
+ async def check_for_cancelled_flow_runs(self) -> list["FlowRun"]:
if not self.is_setup:
raise RuntimeError(
"Worker is not set up. Please make sure you are running this worker "
@@ -429,7 +434,7 @@ async def check_for_cancelled_flow_runs(self):
return cancelling_flow_runs
- async def cancel_run(self, flow_run: "FlowRun"):
+ async def cancel_run(self, flow_run: "FlowRun") -> None:
run_logger = self.get_flow_run_logger(flow_run)
try:
diff --git a/src/prefect/workers/server.py b/src/prefect/workers/server.py
index e513df3c22ef..a073621afc9a 100644
--- a/src/prefect/workers/server.py
+++ b/src/prefect/workers/server.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Any
import uvicorn
import uvicorn.server
@@ -10,14 +10,13 @@
PREFECT_WORKER_WEBSERVER_PORT,
)
from prefect.workers.base import BaseWorker
-from prefect.workers.process import ProcessWorker
def build_healthcheck_server(
- worker: Union[BaseWorker, ProcessWorker],
+ worker: BaseWorker[Any, Any, Any],
query_interval_seconds: float,
log_level: str = "error",
-):
+) -> uvicorn.Server:
"""
Build a healthcheck FastAPI server for a worker.
@@ -54,7 +53,7 @@ def perform_health_check():
def start_healthcheck_server(
- worker: Union[BaseWorker, ProcessWorker],
+ worker: BaseWorker[Any, Any, Any],
query_interval_seconds: float,
log_level: str = "error",
) -> None:
diff --git a/tests/infrastructure/provisioners/test_coiled.py b/tests/infrastructure/provisioners/test_coiled.py
new file mode 100644
index 000000000000..eca5cdceb11f
--- /dev/null
+++ b/tests/infrastructure/provisioners/test_coiled.py
@@ -0,0 +1,186 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+from uuid import UUID
+
+import pytest
+
+from prefect.blocks.core import Block
+from prefect.client.orchestration import PrefectClient
+from prefect.infrastructure.provisioners.coiled import CoiledPushProvisioner
+
+
+@pytest.fixture(autouse=True)
+async def coiled_credentials_block_cls():
+ class MockCoiledCredentials(Block):
+ _block_type_name = "Coiled Credentials"
+ api_token: str
+
+ await MockCoiledCredentials.register_type_and_schema()
+
+ return MockCoiledCredentials
+
+
+@pytest.fixture
+async def coiled_credentials_block_id(coiled_credentials_block_cls: type[Block]):
+ block_doc_id = await coiled_credentials_block_cls(api_token="existing_token").save(
+ "work-pool-name-coiled-credentials", overwrite=True
+ )
+
+ return block_doc_id
+
+
+@pytest.fixture
+def mock_run_process():
+ with patch("prefect.infrastructure.provisioners.coiled.run_process") as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_coiled(monkeypatch):
+ mock = MagicMock()
+ monkeypatch.setattr("prefect.infrastructure.provisioners.coiled.coiled", mock)
+ yield mock
+
+
+@pytest.fixture
+def mock_importlib():
+ with patch("prefect.infrastructure.provisioners.coiled.importlib") as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_confirm():
+ with patch("prefect.infrastructure.provisioners.coiled.Confirm") as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_dask_config():
+ with patch(
+ "prefect.infrastructure.provisioners.coiled.CoiledPushProvisioner._get_coiled_token"
+ ) as mock:
+ mock.return_value = "local-api-token-from-dask-config"
+ yield mock
+
+
+async def test_provision(
+ prefect_client: PrefectClient,
+ mock_run_process: AsyncMock,
+ mock_coiled: MagicMock,
+ mock_dask_config: MagicMock,
+ mock_confirm: MagicMock,
+ mock_importlib: MagicMock,
+):
+ """
+ Test provision from a clean slate:
+ - Coiled is not installed
+ - Coiled token does not exist
+ - CoiledCredentials block does not exist
+ """
+ provisioner = CoiledPushProvisioner()
+ provisioner.console.is_interactive = True
+
+ mock_confirm.ask.side_effect = [
+ True,
+ True,
+ True,
+ ] # confirm provision, install coiled, create new token
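+ # the first import of coiled fails (not installed); later imports succeed after installation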
+ mock_importlib.import_module.side_effect = [
+ ModuleNotFoundError,
+ mock_coiled,
+ mock_coiled,
+ ]
+ # simulate coiled token creation
+ mock_coiled.config.Config.return_value.get.side_effect = [
+ None,
+ None,
+ "mock_token",
+ ]
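+ # the first two config reads find no token; the third returns the newly created one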
+
+ work_pool_name = "work-pool-name"
+ base_job_template = {"variables": {"properties": {"credentials": {}}}}
+
+ result = await provisioner.provision(
+ work_pool_name, base_job_template, client=prefect_client
+ )
+
+ # Check if the block document exists and has expected values
+ block_document = await prefect_client.read_block_document_by_name(
+ "work-pool-name-coiled-credentials", "coiled-credentials"
+ )
+
+ assert block_document.data["api_token"], str == "mock_token"
+
+ # Check if the base job template was updated
+ assert result["variables"]["properties"]["credentials"] == {
+ "default": {"$ref": {"block_document_id": str(block_document.id)}},
+ }
+
+
+async def test_provision_existing_coiled_credentials_block(
+ prefect_client: PrefectClient,
+ coiled_credentials_block_id: UUID,
+ mock_run_process: AsyncMock,
+):
+ """
+ Test provision with an existing CoiledCredentials block.
+ """
+ provisioner = CoiledPushProvisioner()
+
+ work_pool_name = "work-pool-name"
+ base_job_template = {"variables": {"properties": {"credentials": {}}}}
+
+ result = await provisioner.provision(
+ work_pool_name, base_job_template, client=prefect_client
+ )
+
+ # Check if the base job template was updated
+ assert result["variables"]["properties"]["credentials"] == {
+ "default": {"$ref": {"block_document_id": str(coiled_credentials_block_id)}},
+ }
+
+ mock_run_process.assert_not_called()
+
+
+async def test_provision_existing_coiled_credentials(
+ prefect_client: PrefectClient,
+ mock_run_process: AsyncMock,
+ mock_coiled: MagicMock,
+ mock_dask_config: MagicMock,
+ mock_confirm: MagicMock,
+ mock_importlib: MagicMock,
+):
+ """
+ Test provision where the user has coiled installed and an existing Coiled configuration.
+ """
+ provisioner = CoiledPushProvisioner()
+ mock_confirm.ask.side_effect = [
+ True,
+ ] # confirm provision
+ mock_importlib.import_module.side_effect = [
+ mock_coiled,
+ mock_coiled,
+ ] # coiled is already installed
+ mock_coiled.config.Config.return_value.get.side_effect = [
+ "mock_token",
+ ] # coiled config exists
+
+ work_pool_name = "work-pool-name"
+ base_job_template = {"variables": {"properties": {"credentials": {}}}}
+
+ result = await provisioner.provision(
+ work_pool_name, base_job_template, client=prefect_client
+ )
+
+ # Check if the block document exists and has expected values
+ block_document = await prefect_client.read_block_document_by_name(
+ "work-pool-name-coiled-credentials", "coiled-credentials"
+ )
+
+ assert block_document.data["api_token"], str == "mock_token"
+
+ # Check if the base job template was updated
+ assert result["variables"]["properties"]["credentials"] == {
+ "default": {"$ref": {"block_document_id": str(block_document.id)}},
+ }
+
+ mock_run_process.assert_not_called()
diff --git a/tests/utilities/test_pydantic.py b/tests/utilities/test_pydantic.py
index d867a7a3ed38..10bed971b300 100644
--- a/tests/utilities/test_pydantic.py
+++ b/tests/utilities/test_pydantic.py
@@ -1,3 +1,4 @@
+import warnings
from pathlib import Path
import cloudpickle
@@ -6,12 +7,15 @@
from prefect.utilities.dispatch import register_type
from prefect.utilities.pydantic import (
- JsonPatch,
PartialModel,
add_cloudpickle_reduction,
get_class_fields_only,
)
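+# importing JsonPatch triggers a DeprecationWarning, so pull it in with the warning suppressed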
+with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ from prefect.utilities.pydantic import JsonPatch
+
class SimplePydantic(BaseModel):
x: int
diff --git a/ui-v2/src/api/prefect.ts b/ui-v2/src/api/prefect.ts
index 738414555336..5bf741b1fea9 100644
--- a/ui-v2/src/api/prefect.ts
+++ b/ui-v2/src/api/prefect.ts
@@ -8487,7 +8487,9 @@ export interface components {
/** @description The current task run state. */
state?: components["schemas"]["State"] | null;
};
- TaskRunCount: Record<string, never>;
+ TaskRunCount: {
+ [key: string]: number;
+ };
/**
* TaskRunCreate
* @description Data used by the Prefect REST API to create a task run
@@ -15409,7 +15411,7 @@ export interface operations {
[name: string]: unknown;
};
content: {
- "application/json": unknown;
+ "application/json": Record;
};
};
/** @description Validation Error */
@@ -15671,7 +15673,7 @@ export interface operations {
[name: string]: unknown;
};
content: {
- "application/json": unknown;
+ "application/json": string;
};
};
/** @description Validation Error */
diff --git a/ui/src/pages/Unauthenticated.vue b/ui/src/pages/Unauthenticated.vue
index faa79ee9fa3f..5b3d4bf7e1a7 100644
--- a/ui/src/pages/Unauthenticated.vue
+++ b/ui/src/pages/Unauthenticated.vue
@@ -8,7 +8,7 @@
=> {
try {
localStorage.setItem('prefect-password', btoa(password.value))
- router.push(props.redirect || '/')
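+ // verify the stored password against the API before navigating away from the login page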
+ api.admin.authCheck().then(status_code => {
+ if (status_code == 401) {
+ localStorage.removeItem('prefect-password')
+ showToast('Authentication failed.', 'error', { timeout: false })
+ if (router.currentRoute.value.name !== 'login') {
+ router.push({
+ name: 'login',
+ query: { redirect: router.currentRoute.value.fullPath }
+ })
+ }
+ } else {
+ api.health.isHealthy().then(healthy => {
+ if (!healthy) {
+ showToast(`Can't connect to Server API at ${config.baseUrl}. Check that it's accessible from your machine.`, 'error', { timeout: false })
+ }
+ router.push(props.redirect || '/')
+ })
+ }
+ })
} catch (e) {
localStorage.removeItem('prefect-password')
error.value = 'Invalid password'