feat(backend): add useful e2e tracing
Enabled the tracing component inside the git agent in order to have
end-to-end spans for a single request.
Added spans at the most useful places, such as the message bus's
execute_message and the database's execute_query.

Signed-off-by: Fatih Acar <fatih@opsmill.com>
fatih-acar committed Mar 3, 2024
1 parent 2bdf259 commit c5fd503
Showing 11 changed files with 177 additions and 108 deletions.
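
Most of the hunks below apply the same span pattern: fetch a tracer via otel_extensions.get_tracer, open a span around the unit of work with start_as_current_span, and attach a small identifying attribute. A minimal sketch of that pattern, independent of the Infrahub code (the function name and attribute here are illustrative only):

from otel_extensions import get_tracer

async def handle_unit_of_work(routing_key: str) -> None:
    # Open a span named after the operation; it becomes a child of whatever
    # span is currently active (for example an incoming request span).
    with get_tracer(__name__).start_as_current_span("handle_unit_of_work") as span:
        # Keep attributes small and low-cardinality; here the routing key
        # identifies which kind of work was processed.
        span.set_attribute("routing_key", routing_key)
        ...  # do the actual work inside the span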
12 changes: 11 additions & 1 deletion backend/infrahub/cli/git_agent.py
@@ -9,7 +9,7 @@
from prometheus_client import start_http_server
from rich.logging import RichHandler

from infrahub import config
from infrahub import __version__, config
from infrahub.components import ComponentType
from infrahub.core.initialization import initialization
from infrahub.database import InfrahubDatabase, get_db
@@ -21,6 +21,7 @@
from infrahub.services import InfrahubServices
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.trace import configure_trace

app = typer.Typer()

@@ -66,6 +67,15 @@ async def _start(debug: bool, port: int) -> None:
client = await InfrahubClient.init(address=config.SETTINGS.main.internal_address, retry_on_failure=True, log=log)
await client.branch.all()

# Initialize trace
if config.SETTINGS.trace.enable:
configure_trace(
service="infrahub-git-agent",
version=__version__,
exporter_endpoint=config.SETTINGS.trace.trace_endpoint,
exporter_protocol=config.SETTINGS.trace.exporter_protocol,
)

# Initialize the lock
initialize_lock()

30 changes: 17 additions & 13 deletions backend/infrahub/database/__init__.py
@@ -13,6 +13,7 @@
Record,
)
from neo4j.exceptions import ClientError, ServiceUnavailable
from otel_extensions import get_tracer
from typing_extensions import Self

from infrahub import config
@@ -161,19 +162,22 @@ async def close(self):
async def execute_query(
self, query: str, params: Optional[Dict[str, Any]] = None, name: Optional[str] = "undefined"
) -> List[Record]:
with QUERY_EXECUTION_METRICS.labels(str(self._session_mode), name).time():
if self.is_transaction:
execution_method = await self.transaction()
else:
execution_method = await self.session()

try:
response = await execution_method.run(query=query, parameters=params)
except ServiceUnavailable as exc:
log.error("Database Service unavailable", error=str(exc))
raise DatabaseError(message="Unable to connect to the database") from exc

return [item async for item in response]
with get_tracer(__name__).start_as_current_span("execute_db_query") as span:
span.set_attribute("query", query)

with QUERY_EXECUTION_METRICS.labels(str(self._session_mode), name).time():
if self.is_transaction:
execution_method = await self.transaction()
else:
execution_method = await self.session()

try:
response = await execution_method.run(query=query, parameters=params)
except ServiceUnavailable as exc:
log.error("Database Service unavailable", error=str(exc))
raise DatabaseError(message="Unable to connect to the database") from exc

return [item async for item in response]

def render_list_comprehension(self, items: str, item_name: str) -> str:
if self.db_type == DatabaseType.MEMGRAPH:
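
For reference, the effect of the new execute_db_query span can be observed with the OpenTelemetry SDK's in-memory exporter. This is only an illustrative, test-style sketch (not part of the commit); the query string is a placeholder:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

# Route finished spans to an in-memory exporter so they can be inspected.
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
trace.set_tracer_provider(provider)

with trace.get_tracer(__name__).start_as_current_span("execute_db_query") as span:
    span.set_attribute("query", "MATCH (n) RETURN n LIMIT 1")  # placeholder query

finished = exporter.get_finished_spans()
assert finished[0].name == "execute_db_query"
assert finished[0].attributes["query"].startswith("MATCH")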
2 changes: 2 additions & 0 deletions backend/infrahub/graphql/mutations/branch.py
@@ -5,6 +5,7 @@
from graphql import GraphQLResolveInfo
from infrahub_sdk.utils import extract_fields, extract_fields_first_node
from typing_extensions import Self
from otel_extensions import instrumented

from infrahub import config, lock
from infrahub.core import registry
@@ -46,6 +47,7 @@ class Arguments:
ok = Boolean()
object = Field(BranchType)

@instrumented
@classmethod
async def mutate(
cls, root: dict, info: GraphQLResolveInfo, data: BranchCreateInput, background_execution: bool = False
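
The @instrumented decorator from otel_extensions wraps the decorated callable in a span of its own, so the branch-create mutation above gets covered without a manual with-block. A minimal sketch of the same usage on a stand-alone async function (the function name is illustrative):

from otel_extensions import instrumented

@instrumented
async def create_branch(name: str) -> None:
    # A span covering the whole coroutine is opened and closed automatically.
    ...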
37 changes: 21 additions & 16 deletions backend/infrahub/message_bus/operations/__init__.py
@@ -13,6 +13,8 @@
transform,
trigger,
)
from otel_extensions import get_tracer

from infrahub.message_bus.types import MessageTTL
from infrahub.services import InfrahubServices
from infrahub.tasks.check import set_check_status
@@ -68,19 +70,22 @@


async def execute_message(routing_key: str, message_body: bytes, service: InfrahubServices):
message_data = json.loads(message_body)
message = messages.MESSAGE_MAP[routing_key](**message_data)
message.set_log_data(routing_key=routing_key)
try:
await COMMAND_MAP[routing_key](message=message, service=service)
except Exception as exc: # pylint: disable=broad-except
if message.reply_requested:
response = RPCErrorResponse(errors=[str(exc)], initial_message=message.model_dump())
await service.reply(message=response, initiator=message)
return
if message.reached_max_retries:
service.log.exception("Message failed after maximum number of retries", error=exc)
await set_check_status(message, conclusion="failure", service=service)
return
message.increase_retry_count()
await service.send(message, delay=MessageTTL.FIVE)
with get_tracer(__name__).start_as_current_span("execute_message") as span:
span.set_attribute("routing_key", routing_key)

message_data = json.loads(message_body)
message = messages.MESSAGE_MAP[routing_key](**message_data)
message.set_log_data(routing_key=routing_key)
try:
await COMMAND_MAP[routing_key](message=message, service=service)
except Exception as exc: # pylint: disable=broad-except
if message.reply_requested:
response = RPCErrorResponse(errors=[str(exc)], initial_message=message.model_dump())
await service.reply(message=response, initiator=message)
return
if message.reached_max_retries:
service.log.exception("Message failed after maximum number of retries", error=exc)
await set_check_status(message, conclusion="failure", service=service)
return
message.increase_retry_count()
await service.send(message, delay=MessageTTL.FIVE)
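
Because start_as_current_span activates the span in the current context, a database query made while handling a message would typically produce an execute_db_query span nested under execute_message, which is what yields a single end-to-end trace per request. A minimal sketch of that nesting (span names taken from the hunks, the rest illustrative), assuming a tracer provider has already been configured:

from opentelemetry import trace

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("execute_message"):
    # Any span started while "execute_message" is active becomes its child,
    # so the whole message handling shows up as one tree in the trace backend.
    with tracer.start_as_current_span("execute_db_query"):
        ...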
35 changes: 21 additions & 14 deletions backend/infrahub/server.py
@@ -13,7 +13,7 @@
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from infrahub_sdk.timestamp import TimestampFormatError
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor, Span
from pydantic import ValidationError
from starlette_exporter import PrometheusMiddleware, handle_metrics

@@ -32,7 +32,7 @@
from infrahub.services import InfrahubServices, services
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.trace import add_span_exception, configure_trace, get_traceid, get_tracer
from infrahub.trace import add_span_exception, configure_trace, get_traceid
from infrahub.worker import WORKER_IDENTITY


@@ -42,8 +42,8 @@ async def app_initialization(application: FastAPI) -> None:
# Initialize trace
if config.SETTINGS.trace.enable:
configure_trace(
service="infrahub-server",
version=__version__,
exporter_type=config.SETTINGS.trace.exporter_type,
exporter_endpoint=config.SETTINGS.trace.trace_endpoint,
exporter_protocol=config.SETTINGS.trace.exporter_protocol,
)
@@ -92,8 +92,13 @@ async def lifespan(application: FastAPI):
redoc_url="/api/redoc",
)

FastAPIInstrumentor().instrument_app(app, excluded_urls=".*/metrics")
tracer = get_tracer()

def server_request_hook(span: Span, scope: dict): # pylint: disable=unused-argument
if span and span.is_recording():
span.set_attribute("worker", WORKER_IDENTITY)


FastAPIInstrumentor().instrument_app(app, excluded_urls=".*/metrics", server_request_hook=server_request_hook)

FRONTEND_DIRECTORY = os.environ.get("INFRAHUB_FRONTEND_DIRECTORY", os.path.abspath("frontend"))
FRONTEND_ASSET_DIRECTORY = f"{FRONTEND_DIRECTORY}/dist/assets"
@@ -115,15 +120,17 @@ async def lifespan(application: FastAPI):
async def logging_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
clear_log_context()
request_id = correlation_id.get()
with tracer.start_as_current_span("processing request " + request_id):
trace_id = get_traceid()
set_log_data(key="request_id", value=request_id)
set_log_data(key="app", value="infrahub.api")
set_log_data(key="worker", value=WORKER_IDENTITY)
if trace_id:
set_log_data(key="trace_id", value=trace_id)
response = await call_next(request)
return response

set_log_data(key="request_id", value=request_id)
set_log_data(key="app", value="infrahub.api")
set_log_data(key="worker", value=WORKER_IDENTITY)

trace_id = get_traceid()
if trace_id:
set_log_data(key="trace_id", value=trace_id)

response = await call_next(request)
return response


@app.middleware("http")
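
The middleware above no longer opens its own per-request span; it relies on the span created by FastAPIInstrumentor and only reads the active trace id for log correlation. As a point of reference, the generic OpenTelemetry way to obtain that id looks like the sketch below (this is the standard API pattern, not necessarily how infrahub.trace.get_traceid is implemented):

from opentelemetry import trace

def current_trace_id() -> str | None:
    """Return the active trace id as a 32-character hex string, or None."""
    span_context = trace.get_current_span().get_span_context()
    if not span_context.is_valid:
        return None
    return format(span_context.trace_id, "032x")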
61 changes: 51 additions & 10 deletions backend/infrahub/services/adapters/message_bus/rabbitmq.py
@@ -5,7 +5,11 @@
from typing import TYPE_CHECKING, Awaitable, Callable, List, MutableMapping, Optional, Type, TypeVar

import aio_pika
import opentelemetry.instrumentation.aio_pika.span_builder
from infrahub_sdk import UUIDT
from opentelemetry import context, propagate
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.semconv.trace import SpanAttributes

from infrahub import config
from infrahub.components import ComponentType
@@ -24,6 +28,7 @@
AbstractQueue,
AbstractRobustConnection,
)
from opentelemetry.instrumentation.aio_pika.span_builder import SpanBuilder

from infrahub.config import BrokerSettings
from infrahub.services import InfrahubServices
@@ -32,6 +37,29 @@
ResponseClass = TypeVar("ResponseClass")


AioPikaInstrumentor().instrument()


# TODO: remove this once https://github.com/open-telemetry/opentelemetry-python-contrib/issues/1835 is resolved
def patch_spanbuilder_set_channel() -> None:
"""
The default SpanBuilder.set_channel does not work with aio_pika 9.1 and the refactored connection
attribute
"""

def set_channel(self: SpanBuilder, channel: AbstractChannel) -> None:
if hasattr(channel, "_connection"):
url = channel._connection.url
self._attributes.update(
{
SpanAttributes.NET_PEER_NAME: url.host,
SpanAttributes.NET_PEER_PORT: url.port,
}
)

opentelemetry.instrumentation.aio_pika.span_builder.SpanBuilder.set_channel = set_channel # type: ignore


async def _add_request_id(message: InfrahubMessage) -> None:
log_data = get_log_data()
message.meta.request_id = log_data.get("request_id", "")
@@ -54,6 +82,8 @@ def __init__(self, settings: Optional[BrokerSettings] = None) -> None:
self.futures: MutableMapping[str, asyncio.Future] = {}

async def initialize(self, service: InfrahubServices) -> None:
patch_spanbuilder_set_channel()

self.service = service
self.connection = await aio_pika.connect_robust(
host=self.settings.address,
@@ -193,17 +223,28 @@ async def subscribe(self) -> None:
async for message in qiterator:
try:
async with message.process(requeue=False):
# auto instrumentation not supported yet for RPCs, do it ourselves...
token = None
headers = message.headers or {}
ctx = propagate.extract(headers)
if ctx is not None:
token = context.attach(ctx)

clear_log_context()
if message.routing_key in messages.MESSAGE_MAP:
await execute_message(
routing_key=message.routing_key, message_body=message.body, service=self.service
)
else:
self.service.log.error(
"Unhandled routing key for message",
routing_key=message.routing_key,
message=message.body,
)
try:
if message.routing_key in messages.MESSAGE_MAP:
await execute_message(
routing_key=message.routing_key, message_body=message.body, service=self.service
)
else:
self.service.log.error(
"Unhandled routing key for message",
routing_key=message.routing_key,
message=message.body,
)
finally:
if token is not None:
context.detach(token)

except Exception: # pylint: disable=broad-except
self.service.log.exception("Processing error for message %r" % message)
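
The consumer side above extracts the upstream context from the message headers and attaches it before processing. On the publishing side that propagation is handled by the aio_pika instrumentation enabled above, but the underlying mechanism is the standard inject/extract pair. A minimal sketch using a plain dict as the header carrier (illustrative only, not the commit's code):

from opentelemetry import context, propagate

# Publisher: serialize the current trace context into the message headers.
headers: dict = {}
propagate.inject(headers)

# Consumer: restore that context so spans created while handling the
# message are linked to the publisher's trace, then detach when done.
ctx = propagate.extract(headers)
token = context.attach(ctx)
try:
    ...  # handle the message; spans started here join the publisher's trace
finally:
    context.detach(token)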
56 changes: 7 additions & 49 deletions backend/infrahub/trace.py
@@ -1,18 +1,6 @@
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as GRPCSpanExporter,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as HTTPSpanExporter,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.trace import StatusCode


def get_tracer(name: str = "infrahub") -> trace.Tracer:
return trace.get_tracer(name)
from otel_extensions import TelemetryOptions, init_telemetry_provider


def get_current_span_with_context() -> trace.Span:
@@ -54,42 +42,12 @@ def add_span_exception(exception: Exception) -> None:
current_span.record_exception(exception)


def create_tracer_provider(
version: str, exporter_type: str, exporter_endpoint: str = None, exporter_protocol: str = None
) -> TracerProvider:
# Create a BatchSpanProcessor exporter based on the type
if exporter_type == "console":
exporter = ConsoleSpanExporter()
elif exporter_type == "otlp":
if not exporter_endpoint:
raise ValueError("Exporter type is set to otlp but endpoint is not set")
if exporter_protocol == "http/protobuf":
exporter = HTTPSpanExporter(endpoint=exporter_endpoint)
elif exporter_protocol == "grpc":
exporter = GRPCSpanExporter(endpoint=exporter_endpoint)
else:
raise ValueError("Exporter type unsupported by Infrahub")

# Resource can be required for some backends, e.g. Jaeger
resource = Resource(attributes={"service.name": "infrahub", "service.version": version})
span_processor = BatchSpanProcessor(exporter)
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(span_processor)

return tracer_provider


def configure_trace(
version: str, exporter_type: str, exporter_endpoint: str = None, exporter_protocol: str = None
service: str, version: str, exporter_endpoint: str | None = None, exporter_protocol: str = None
) -> None:
# Create a trace provider with the exporter
tracer_provider = create_tracer_provider(
version=version,
exporter_type=exporter_type,
exporter_endpoint=exporter_endpoint,
exporter_protocol=exporter_protocol,
options = TelemetryOptions(
OTEL_SERVICE_NAME=service,
OTEL_EXPORTER_OTLP_ENDPOINT=exporter_endpoint,
OTEL_EXPORTER_OTLP_PROTOCOL=exporter_protocol,
)
tracer_provider.get_tracer(__name__)

# Register the trace provider
trace.set_tracer_provider(tracer_provider)
init_telemetry_provider(options, **{"service.version": version})
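
With the provider setup delegated to otel-extensions, configuring tracing at startup reduces to a single call. A usage sketch mirroring the git agent hunk above; the version string and endpoint are example values, not real deployment settings:

from infrahub.trace import configure_trace

configure_trace(
    service="infrahub-git-agent",
    version="0.12.0",  # example version string
    exporter_endpoint="http://localhost:4317",  # example OTLP endpoint
    exporter_protocol="grpc",
)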