Skip to content

Commit

Permalink
Better context management
Browse files Browse the repository at this point in the history
  • Loading branch information
ItayGibel-helios committed Nov 23, 2021
1 parent 1c0eea0 commit 0d1d010
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
API
___
"""
import atexit
from typing import Collection

import kafka
Expand All @@ -50,6 +51,7 @@
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.kafka.package import _instruments
from opentelemetry.instrumentation.kafka.utils import (
KafkaInstrumentorContextManager,
_wrap_next,
_wrap_send,
dummy_callback,
Expand Down Expand Up @@ -83,11 +85,16 @@ def _instrument(self, **kwargs):
__name__, __version__, tracer_provider=tracer_provider
)

context_manager = KafkaInstrumentorContextManager()
atexit.register(context_manager.close)

wrap_function_wrapper(
kafka.KafkaProducer, "send", _wrap_send(tracer, produce_hook)
)
wrap_function_wrapper(
kafka.KafkaConsumer, "__next__", _wrap_next(tracer, consume_hook)
kafka.KafkaConsumer,
"__next__",
_wrap_next(tracer, context_manager, consume_hook),
)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from logging import getLogger
from typing import Callable, Dict, List, Optional

from kafka import KafkaConsumer

from opentelemetry import trace
from opentelemetry.context import attach
from opentelemetry.context import attach, detach
from opentelemetry.context.context import Context
from opentelemetry.propagate import extract, inject
from opentelemetry.propagators import textmap
from opentelemetry.semconv.trace import SpanAttributes
Expand All @@ -13,6 +16,46 @@
_LOG = getLogger(__name__)


class KafkaInstrumentorContextManager:
def __init__(self):
self.spans = dict()
self.tokens = dict()

def set_consumer_context(
self, consumer: KafkaConsumer, context: Context, span: Span
):
self.set_span(consumer, span)
self.attach_context(consumer, context)

def set_span(self, consumer: KafkaConsumer, span: Span):
self.close_span(consumer)
self.spans[consumer] = span

def close_span(self, consumer: KafkaConsumer):
if consumer in self.spans:
self.spans.get(consumer).close()
del self.spans[consumer]

def attach_context(self, consumer: KafkaConsumer, context: Context):
self.detach_context(consumer)
self.tokens[consumer] = attach(context)

def detach_context(self, consumer: KafkaConsumer):
if consumer in self.tokens:
detach(self.tokens.get(consumer))
del self.tokens[consumer]

def close(self, kafka_consumer: KafkaConsumer = None):
if kafka_consumer:
self.close_span(kafka_consumer)
self.detach_context(kafka_consumer)
else:
for consumer in self.spans:
self.close_span(consumer)
for consumer in self.tokens:
self.detach_context(consumer)


class KafkaPropertiesExtractor:
@staticmethod
def extract_bootstrap_servers(instance):
Expand Down Expand Up @@ -167,26 +210,30 @@ def _traced_send(func, instance, args, kwargs):


def _start_consume_span_with_extracted_context(
tracer: Tracer, headers: List, topic: str
tracer: Tracer,
context_manager: KafkaInstrumentorContextManager,
instance: KafkaConsumer,
headers: List,
topic: str,
) -> Span:
extracted_context = extract(headers, getter=_kafka_getter)
span_name = _get_span_name("receive", topic)
span = tracer.start_span(
span_name, context=extracted_context, kind=trace.SpanKind.CONSUMER
)
new_context = set_span_in_context(span, extracted_context)
attach(new_context)
context_manager.set_consumer_context(instance, new_context, span)
return span


def _wrap_next(tracer: Tracer, consume_hook: HookT) -> Callable:
def _wrap_next(
tracer: Tracer,
context_manager: KafkaInstrumentorContextManager,
consume_hook: HookT,
) -> Callable:
def _traced_next(func, instance, args, kwargs):
# End the current span if exists before processing the next record
current_span = trace.get_current_span()
if current_span.is_recording() and current_span.name.startswith(
"receive"
):
current_span.end()
context_manager.close(instance)

record = func(*args, **kwargs)

Expand All @@ -198,7 +245,7 @@ def _traced_next(func, instance, args, kwargs):
)
partition = record.partition
span = _start_consume_span_with_extracted_context(
tracer, headers, topic
tracer, context_manager, instance, headers, topic
)
with trace.use_span(span):
_enrich_span(span, bootstrap_servers, topic, partition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def test_wrap_send(
)
self.assertEqual(retval, original_send_callback.return_value)

@mock.patch("opentelemetry.trace.get_current_span")
@mock.patch("opentelemetry.trace.use_span")
@mock.patch(
"opentelemetry.instrumentation.kafka.utils._start_consume_span_with_extracted_context"
Expand All @@ -91,31 +90,34 @@ def test_wrap_next(
enrich_span: mock.MagicMock,
start_consume_span_with_extracted_context: mock.MagicMock,
use_span: mock.MagicMock,
get_current_span: mock.MagicMock,
) -> None:
tracer = mock.MagicMock()
consume_hook = mock.MagicMock()
original_next_callback = mock.MagicMock()
kafka_consumer = mock.MagicMock()
context_manager = mock.MagicMock()

wrapped_next = _wrap_next(tracer, consume_hook)
wrapped_next = _wrap_next(tracer, context_manager, consume_hook)
record = wrapped_next(
original_next_callback, kafka_consumer, self.args, self.kwargs
)

extract_bootstrap_servers.assert_called_once_with(kafka_consumer)
bootstrap_servers = extract_bootstrap_servers.return_value
get_current_span.assert_called_once()
current_span = get_current_span.return_value
current_span.end.assert_called_once()

context_manager.close.assert_called_once_with(kafka_consumer)

original_next_callback.assert_called_once_with(
*self.args, **self.kwargs
)
self.assertEqual(record, original_next_callback.return_value)

start_consume_span_with_extracted_context.assert_called_once_with(
tracer, record.headers, record.topic
tracer,
context_manager,
kafka_consumer,
record.headers,
record.topic,
)
span = start_consume_span_with_extracted_context.return_value
use_span.assert_called_once_with(span)
Expand All @@ -124,20 +126,24 @@ def test_wrap_next(
)
consume_hook.assert_called_once_with(span, self.args, self.kwargs)

@mock.patch("opentelemetry.context.attach")
@mock.patch("opentelemetry.trace.set_span_in_context")
@mock.patch("opentelemetry.propagate.extract")
def test_start_consume_span_with_extracted_context(
self,
extract: mock.MagicMock,
set_span_in_context: mock.MagicMock,
attach: mock.MagicMock,
):
tracer = mock.MagicMock()
context_manager = mock.MagicMock()
kafka_consumer = mock.MagicMock()
expected_span_name = _get_span_name("receive", self.topic_name)

_start_consume_span_with_extracted_context(
tracer, self.headers, self.topic_name
tracer,
context_manager,
kafka_consumer,
self.headers,
self.topic_name,
)

extract.assert_called_once_with(self.headers, _kafka_getter)
Expand All @@ -148,4 +154,6 @@ def test_start_consume_span_with_extracted_context(
span = tracer.start_span.return_value
set_span_in_context.assert_called_once_with(span, context)
new_context = set_span_in_context.return_value
attach.assert_called_once_with(new_context)
context_manager.set_consumer_context.assert_called_once_with(
kafka_consumer, new_context, span
)

0 comments on commit 0d1d010

Please sign in to comment.