diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py
index 80aabc1e..13b5b75f 100644
--- a/pccommon/pccommon/tracing.py
+++ b/pccommon/pccommon/tracing.py
@@ -3,11 +3,14 @@
 import re
 from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast
 
+import fastapi
 from fastapi import Request, Response
 from opencensus.ext.azure.trace_exporter import AzureExporter
 from opencensus.trace.samplers import ProbabilitySampler
 from opencensus.trace.span import SpanKind
 from opencensus.trace.tracer import Tracer
+from opentelemetry import trace
+from starlette.datastructures import QueryParams
 
 from pccommon.config import get_apis_config
 from pccommon.constants import (
@@ -25,6 +28,11 @@
 
 logger = logging.getLogger(__name__)
 
+COLLECTION = "spatio.collection"
+COLLECTIONS = "spatio.collections"
+ITEM = "spatio.item"
+ITEMS = "spatio.items"
+
 exporter = (
     AzureExporter(
         connection_string=(
@@ -249,3 +257,61 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
                             return result
     # No collection was found
     return None
+
+
+def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None:
+    """
+    Try to add the Collection ID and Item ID from a search to the current span.
+
+    Best-effort: a body that cannot be decoded is logged and skipped so that
+    tracing never fails the search request itself.
+    """
+    try:
+        body = json.loads(search_json)
+    except json.JSONDecodeError as e:
+        logger.warning(
+            "Unable to parse search body as JSON. Ignoring collection "
+            f"parameter. {e}"
+        )
+        return
+
+    collection_id, item_id = parse_collection_from_search(
+        body, request.method, request.query_params
+    )
+    span = trace.get_current_span()
+
+    if collection_id is not None:
+        span.set_attribute(COLLECTIONS, collection_id)
+
+    if item_id is not None:
+        span.set_attribute(ITEMS, item_id)
+
+
+def parse_collection_from_search(
+    body: dict,
+    method: str,
+    query_params: QueryParams,
+) -> Tuple[Optional[str], Optional[str]]:
+    """
+    Parse the collection and item ids from a search request.
+
+    The search endpoint is a bit of a special case. If it's a GET, the collection
+    and item ids are in the querystring. If it's a POST, the collection and item may
+    be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.
+    """
+    if method.lower() == "get":
+        collection_id = query_params.get("collections")
+        item_id = query_params.get("ids")
+        return (collection_id, item_id)
+    elif method.lower() == "post":
+        try:
+            if "collections" in body:
+                return _parse_queryjson(body)
+            elif "filter" in body:
+                return _parse_cqljson(body["filter"])
+        except json.JSONDecodeError as e:
+            logger.warning(
+                "Unable to parse search body as JSON. Ignoring collection "
+                f"parameter. {e}"
+            )
+    return (None, None)
diff --git a/pccommon/setup.py b/pccommon/setup.py
index 60cc7e8e..4197ae50 100644
--- a/pccommon/setup.py
+++ b/pccommon/setup.py
@@ -20,6 +20,8 @@
     "html-sanitizer==2.4",
     # Soon available as lxml[html_clean]
     "lxml_html_clean==0.1.0",
+    "opentelemetry-api==1.21.0",
+    "opentelemetry-sdk==1.21.0",
 ]
 
 extra_reqs = {
diff --git a/pcstac/pcstac/client.py b/pcstac/pcstac/client.py
index 9e162d74..017f6034 100644
--- a/pcstac/pcstac/client.py
+++ b/pcstac/pcstac/client.py
@@ -20,6 +20,7 @@
 from pccommon.constants import DEFAULT_COLLECTION_REGION
 from pccommon.logging import get_custom_dimensions
 from pccommon.redis import back_pressure, cached_result, rate_limit
+from pccommon.tracing import add_stac_attributes_from_search
 from pcstac.config import API_DESCRIPTION, API_LANDING_PAGE_ID, API_TITLE, get_settings
 from pcstac.contants import (
     CACHE_KEY_COLLECTION,
@@ -227,6 +228,8 @@ async def _fetch() -> ItemCollection:
             return item_collection
 
         search_json = search_request.json()
+        add_stac_attributes_from_search(search_json, request)
+
         logger.info(
             "STAC: Item search body",
             extra=get_custom_dimensions({"search_body": search_json}, request),
diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py
index 09304224..405ae28c 100644
--- a/pcstac/pcstac/main.py
+++ b/pcstac/pcstac/main.py
@@ -1,9 +1,9 @@
 """FastAPI application using PGStac."""
 import logging
 import os
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
 
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, Response
 from fastapi.exceptions import RequestValidationError, StarletteHTTPException
 from fastapi.openapi.utils import get_openapi
 from fastapi.responses import ORJSONResponse
@@ -18,6 +18,7 @@
 from pccommon.middleware import add_timeout, http_exception_handler
 from pccommon.openapi import fixup_schema
 from pccommon.redis import connect_to_redis
+from pccommon.tracing import trace_request
 from pcstac.api import PCStacApi
 from pcstac.client import PCClient
 from pcstac.config import (
@@ -75,6 +76,15 @@
 
 add_timeout(app, app_settings.request_timeout)
 
+
+@app.middleware("http")
+async def _request_middleware(
+    request: Request, call_next: Callable[[Request], Awaitable[Response]]
+) -> Response:
+    """Add a trace to all requests."""
+    return await trace_request(ServiceName.STAC, request, call_next)
+
+
 # Note: If requests are being sent through an application gateway like
 # nginx-ingress, you may need to configure CORS through that system.
 app.add_middleware(
diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py
index b6be2bdc..3553d0f4 100755
--- a/pctiler/pctiler/main.py
+++ b/pctiler/pctiler/main.py
@@ -1,9 +1,9 @@
 #!/usr/bin/env python3
 import logging
 import os
-from typing import Dict, List
+from typing import Awaitable, Callable, Dict, List
 
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, Response
 from fastapi.openapi.utils import get_openapi
 from morecantile.defaults import tms as defaultTileMatrices
 from morecantile.models import TileMatrixSet
@@ -21,6 +21,7 @@
 from pccommon.logging import ServiceName, init_logging
 from pccommon.middleware import add_timeout, http_exception_handler
 from pccommon.openapi import fixup_schema
+from pccommon.tracing import trace_request
 from pctiler.config import get_settings
 from pctiler.endpoints import (
     configuration,
@@ -89,6 +90,15 @@
 add_exception_handlers(app, DEFAULT_STATUS_CODES)
 add_exception_handlers(app, MOSAIC_STATUS_CODES)
 
+
+@app.middleware("http")
+async def _request_middleware(
+    request: Request, call_next: Callable[[Request], Awaitable[Response]]
+) -> Response:
+    """Add a trace to all requests."""
+    return await trace_request(ServiceName.TILER, request, call_next)
+
+
 app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600")
 app.add_middleware(TotalTimeMiddleware)
 