refactor: explicitly deserialize messages in event consumer (#157)

Rebecca Graber authored May 5, 2023
1 parent 5363210 commit 6712aa7
Showing 4 changed files with 125 additions and 40 deletions.
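In short: the consumer switches from confluent-kafka's DeserializingConsumer, which decodes values inside poll(), to a plain Consumer plus an explicit AvroDeserializer call. A minimal standalone sketch of the new pattern (not code from this repo; the broker address, registry URL, topic, group id, and schema below are placeholder assumptions):

```python
from confluent_kafka import Consumer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroDeserializer
from confluent_kafka.serialization import MessageField, SerializationContext

# Placeholder schema standing in for the one derived from an OpenEdxPublicSignal.
SCHEMA_STR = """
{"type": "record", "name": "ExampleEvent", "namespace": "example",
 "fields": [{"name": "user_id", "type": "string"}]}
"""

registry_client = SchemaRegistryClient({'url': 'http://localhost:8081'})
value_deserializer = AvroDeserializer(schema_registry_client=registry_client,
                                      schema_str=SCHEMA_STR)

consumer = Consumer({
    'bootstrap.servers': 'localhost:9092',
    'group.id': 'example-group',
    # Auto-commit stays off, as in this consumer's config: committing a whole
    # batch automatically risks losing messages that were never processed.
    'enable.auto.commit': False,
})
consumer.subscribe(['example-topic'])

msg = consumer.poll(timeout=1.0)
if msg is not None and msg.error() is None:
    # A plain Consumer hands back raw bytes; deserialize the value explicitly.
    ctx = SerializationContext(msg.topic(), MessageField.VALUE, msg.headers())
    event_data = value_deserializer(msg.value(), ctx)
    msg.set_value(event_data)  # stash the decoded value, as the loop below does
    consumer.commit(msg)
```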
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -16,6 +16,7 @@ Unreleased

* Switch from ``edx-sphinx-theme`` to ``sphinx-book-theme`` since the former is
deprecated
* Refactored consumer to manually deserialize messages instead of using DeserializingConsumer

[3.9.6] - 2023-02-24
********************
86 changes: 64 additions & 22 deletions edx_event_bus_kafka/internal/consumer.py
@@ -5,10 +5,13 @@
import time
import warnings
from datetime import datetime
from functools import lru_cache

from django.conf import settings
from django.core.management.base import BaseCommand
from django.db import connection
from django.dispatch import receiver
from django.test.signals import setting_changed
from edx_django_utils.monitoring import function_trace, record_exception, set_custom_attribute
from edx_toggles.toggles import SettingToggle
from openedx_events.event_bus.avro.deserializer import AvroSignalDeserializer
@@ -29,9 +32,10 @@
# See https://github.com/openedx/event-bus-kafka/blob/main/docs/decisions/0005-optional-import-of-confluent-kafka.rst
try:
import confluent_kafka
from confluent_kafka import TIMESTAMP_NOT_AVAILABLE, DeserializingConsumer
from confluent_kafka import TIMESTAMP_NOT_AVAILABLE, Consumer
from confluent_kafka.error import KafkaError
from confluent_kafka.schema_registry.avro import AvroDeserializer
from confluent_kafka.serialization import MessageField, SerializationContext
except ImportError: # pragma: no cover
confluent_kafka = None

@@ -116,40 +120,28 @@ def __init__(self, topic, group_id, signal):
self.signal = signal
self.consumer = self._create_consumer()
self._shut_down_loop = False
self.schema_registry_client = get_schema_registry_client()

# return type (Optional[DeserializingConsumer]) removed from signature to avoid error on import
# return type Consumer removed from signature to avoid error on import
def _create_consumer(self):
"""
Create a DeserializingConsumer for events of the given signal instance.
Create a Consumer in the correct consumer group
Returns
None if confluent_kafka is not available.
DeserializingConsumer if it is.
Consumer in the configured consumer group
"""

schema_registry_client = get_schema_registry_client()

signal_deserializer = AvroSignalDeserializer(self.signal)

def inner_from_dict(event_data_dict, ctx=None): # pylint: disable=unused-argument
return signal_deserializer.from_dict(event_data_dict)

consumer_config = load_common_settings()

# We do not deserialize the key because we don't need it for anything yet.
# Also see https://github.com/openedx/openedx-events/issues/86 for some challenges on determining key schema.
consumer_config.update({
'group.id': self.group_id,
'value.deserializer': AvroDeserializer(schema_str=signal_deserializer.schema_string(),
schema_registry_client=schema_registry_client,
from_dict=inner_from_dict),
# Turn off auto commit. Auto commit will commit offsets for the entire batch of messages received,
# potentially resulting in data loss if some of those messages are not fully processed. See
# https://newrelic.com/blog/best-practices/kafka-consumer-config-auto-commit-data-loss
'enable.auto.commit': False,
})

return DeserializingConsumer(consumer_config)
return Consumer(consumer_config)

def _shut_down(self):
"""
@@ -276,7 +268,7 @@ def _consume_indefinitely(self):
with function_trace('_consume_indefinitely_consume_single_message'):
# Before processing, make sure our db connection is still active
_reconnect_to_db_if_needed()

msg.set_value(self._deserialize_message_value(msg))
self.emit_signals_from_message(msg)
consecutive_errors = 0

@@ -381,6 +373,20 @@ def emit_signals_from_message(self, msg):
if AUDIT_LOGGING_ENABLED.is_enabled():
logger.info('Message from Kafka processed successfully')

def _deserialize_message_value(self, msg):
"""
Deserialize an Avro message value
Arguments:
msg (Message): the raw message from the consumer
Returns:
The deserialized message value
"""
signal_deserializer = get_deserializer(self.signal, self.schema_registry_client)
ctx = SerializationContext(msg.topic(), MessageField.VALUE, msg.headers())
return signal_deserializer(msg.value(), ctx)

def _check_receiver_results(self, send_results: list):
"""
Raises exception if any of the receivers produced an exception.
@@ -390,16 +396,16 @@ def _check_receiver_results(self, send_results: list):
"""
error_descriptions = []
errors = []
for receiver, response in send_results:
for signal_receiver, response in send_results:
if not isinstance(response, BaseException):
continue

# Probably every receiver will be a regular function or even a lambda with
# these attrs, so this check is just to be safe.
try:
receiver_name = f"{receiver.__module__}.{receiver.__qualname__}"
receiver_name = f"{signal_receiver.__module__}.{signal_receiver.__qualname__}"
except AttributeError:
receiver_name = str(receiver)
receiver_name = str(signal_receiver)

# The stack traces are already logged by django.dispatcher, so just the error message is fine.
error_descriptions.append(f"{receiver_name}={response!r}")
@@ -649,3 +655,39 @@ def handle(self, *args, **options):
event_consumer.reset_offsets_and_sleep_indefinitely(offset_timestamp=offset_timestamp)
except Exception: # pylint: disable=broad-except
logger.exception("Error consuming Kafka events")


# argument type SchemaRegistryClient for schema_registry_client removed from signature to avoid error on import
@lru_cache
def get_deserializer(signal: OpenEdxPublicSignal, schema_registry_client):
"""
Get the value deserializer for a signal.
This is cached in order to save work re-transforming classes into Avro schemas.
We do not deserialize the key because we don't need it for anything yet.
Also see https://github.com/openedx/openedx-events/issues/86 for some challenges on determining key schema.
Arguments:
signal: The OpenEdxPublicSignal to make a deserializer for.
schema_registry_client: The SchemaRegistryClient instance for the consumer
Returns:
AvroSignalDeserializer for event value
"""
if schema_registry_client is None:
raise Exception('Cannot create Kafka deserializer -- missing library or settings')

signal_deserializer = AvroSignalDeserializer(signal)

def inner_from_dict(event_data_dict, ctx=None): # pylint: disable=unused-argument
return signal_deserializer.from_dict(event_data_dict)

return AvroDeserializer(schema_str=signal_deserializer.schema_string(),
schema_registry_client=schema_registry_client,
from_dict=inner_from_dict)


@receiver(setting_changed)
def _reset_caches(sender, **kwargs): # pylint: disable=unused-argument
"""Reset caches when settings change during unit tests."""
get_deserializer.cache_clear()
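A small illustration of the caching added above: get_deserializer is wrapped in lru_cache, so repeated lookups for the same signal and registry client return the same deserializer and the Avro schema is derived from the signal's types only once; the setting_changed receiver then clears that cache so tests using override_settings rebuild deserializers against fresh settings. A sketch, assuming the event bus Kafka settings are configured so a registry client is available:

```python
from openedx_events.learning.signals import SESSION_LOGIN_COMPLETED

from edx_event_bus_kafka.internal import consumer

client = consumer.get_schema_registry_client()  # None if settings are missing
d1 = consumer.get_deserializer(SESSION_LOGIN_COMPLETED, client)
d2 = consumer.get_deserializer(SESSION_LOGIN_COMPLETED, client)
assert d1 is d2  # lru_cache hit: the same AvroDeserializer object is reused
```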
73 changes: 56 additions & 17 deletions edx_event_bus_kafka/internal/tests/test_consumer.py
@@ -15,7 +15,12 @@
from openedx_events.learning.data import UserData, UserPersonalData
from openedx_events.learning.signals import SESSION_LOGIN_COMPLETED

from edx_event_bus_kafka.internal.consumer import KafkaEventConsumer, ReceiverError, UnusableMessageError
from edx_event_bus_kafka.internal.consumer import (
KafkaEventConsumer,
ReceiverError,
UnusableMessageError,
get_deserializer,
)
from edx_event_bus_kafka.internal.tests.test_utils import FakeMessage, side_effects
from edx_event_bus_kafka.management.commands.consume_events import Command

@@ -46,6 +51,7 @@ def fake_receiver_raises_error(**kwargs):
EVENT_BUS_KAFKA_BOOTSTRAP_SERVERS='bootstrap-servers',
EVENT_BUS_KAFKA_API_KEY='test-key',
EVENT_BUS_KAFKA_API_SECRET='test-secret',
EVENT_BUS_KAFKA_CONSUMER_CONSECUTIVE_ERRORS_LIMIT=10, # prevent infinite looping if tests are broken
)
@ddt.ddt
class TestEmitSignals(TestCase):
Expand All @@ -66,6 +72,9 @@ def setUp(self):
)
)
}
# determined by manual testing
self.normal_event_data_bytes = b'\x00\x00\x01\x86\xb3\xf6\x01\x01\x0cfoobob\x1ebob@foo.example\x0eBob Foo'

self.message_id = uuid1()
self.message_id_bytes = str(self.message_id).encode('utf-8')

@@ -80,7 +89,7 @@ def setUp(self):
('ce_type', self.signal_type_bytes),
],
key=b'\x00\x00\x00\x00\x01\x0cfoobob', # Avro, as observed in manual test
value=self.normal_event_data,
value=self.normal_event_data_bytes,
error=None,
timestamp=(TIMESTAMP_CREATE_TIME, 1675114920123),
)
@@ -177,9 +186,11 @@ def raise_exception():
self.event_consumer, 'emit_signals_from_message',
side_effect=side_effects(mock_emit_side_effects),
) as mock_emit:
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
self.event_consumer.consume_indefinitely()
with patch('edx_event_bus_kafka.internal.consumer.AvroDeserializer',
return_value=lambda _x, _y: self.normal_event_data):
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
self.event_consumer.consume_indefinitely()

# Check that each of the mocked out methods got called as expected.
mock_consumer.subscribe.assert_called_once_with(['local-some-topic'])
@@ -236,10 +247,12 @@ def raise_exception():
self.event_consumer, 'emit_signals_from_message',
side_effect=side_effects([raise_exception] * exception_count)
) as mock_emit:
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
with pytest.raises(Exception) as exc_info:
self.event_consumer.consume_indefinitely()
with patch('edx_event_bus_kafka.internal.consumer.AvroDeserializer',
return_value=lambda _x, _y: self.normal_event_data):
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
with pytest.raises(Exception) as exc_info:
self.event_consumer.consume_indefinitely()

assert mock_emit.call_args_list == [call(self.normal_message)] * exception_count
assert exc_info.value.args == ("Too many consecutive errors, exiting (4 in a row)",)
@@ -265,9 +278,11 @@ def test_connection_reset(self, has_connection, is_usable, reconnect_expected, m
self.event_consumer, 'emit_signals_from_message',
side_effect=side_effects([self.event_consumer._shut_down]) # pylint: disable=protected-access
):
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
self.event_consumer.consume_indefinitely()
with patch('edx_event_bus_kafka.internal.consumer.AvroDeserializer',
return_value=lambda _x, _y: self.normal_event_data):
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
self.event_consumer.consume_indefinitely()

if reconnect_expected:
mock_connection.connect.assert_called_once()
@@ -311,9 +326,11 @@ def raise_exception():
self.event_consumer, 'emit_signals_from_message',
side_effect=side_effects(mock_emit_side_effects)
) as mock_emit:
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
self.event_consumer.consume_indefinitely() # exits normally
with patch('edx_event_bus_kafka.internal.consumer.AvroDeserializer',
return_value=lambda _x, _y: self.normal_event_data):
mock_consumer = Mock(**{'poll.return_value': self.normal_message}, autospec=True)
self.event_consumer.consumer = mock_consumer
self.event_consumer.consume_indefinitely() # exits normally

assert mock_emit.call_args_list == [call(self.normal_message)] * len(mock_emit_side_effects)

@@ -433,6 +450,8 @@ def test_check_event_error(self):
@ddt.data(True, False)
def test_emit_success(self, audit_logging, mock_logger, mock_set_attribute):
self.signal.disconnect(fake_receiver_raises_error) # just successes for this one!
# assume we've already deserialized the data
self.normal_message.set_value(self.normal_event_data)

with override_settings(EVENT_BUS_KAFKA_AUDIT_LOGGING_ENABLED=audit_logging):
self.event_consumer.emit_signals_from_message(self.normal_message)
@@ -458,6 +477,8 @@ def test_emit_success(self, audit_logging, mock_logger, mock_set_attribute):
@patch('edx_event_bus_kafka.internal.consumer.logger', autospec=True)
def test_emit_success_tolerates_missing_timestamp(self, mock_logger, mock_set_attribute):
self.signal.disconnect(fake_receiver_raises_error) # just successes for this one!
# assume we've already deserialized the data
self.normal_message.set_value(self.normal_event_data)
self.normal_message._timestamp = (TIMESTAMP_NOT_AVAILABLE, None) # pylint: disable=protected-access

self.event_consumer.emit_signals_from_message(self.normal_message)
@@ -475,6 +496,8 @@ def test_emit_success_tolerates_missing_timestamp(self, mock_logger, mock_set_at

@patch('django.dispatch.dispatcher.logger', autospec=True)
def test_emit(self, mock_logger):
# assume we've already deserialized the data
self.normal_message.set_value(self.normal_event_data)
with pytest.raises(ReceiverError) as exc_info:
self.event_consumer.emit_signals_from_message(self.normal_message)
self.assert_signal_sent_with(self.signal, self.normal_event_data)
@@ -520,6 +543,8 @@ def test_malformed_receiver_errors(self):

def test_no_type(self):
msg = copy.copy(self.normal_message)
# assume we've already deserialized the data
msg.set_value(self.normal_event_data)
msg._headers = [] # pylint: disable=protected-access

with pytest.raises(UnusableMessageError) as excinfo:
@@ -535,6 +560,8 @@ def test_multiple_types(self):
Very unlikely case, but this gets us coverage.
"""
msg = copy.copy(self.normal_message)
# assume we've already deserialized the data
msg.set_value(self.normal_event_data)
msg._headers = [['ce_type', b'abc'], ['ce_type', b'def']] # pylint: disable=protected-access

with pytest.raises(UnusableMessageError) as excinfo:
@@ -547,6 +574,8 @@ def test_unexpected_signal_type_in_header(self):

def test_unexpected_signal_type_in_header(self):
msg = copy.copy(self.normal_message)
# assume we've already deserialized the data
msg.set_value(self.normal_event_data)
msg._headers = [ # pylint: disable=protected-access
['ce_type', b'xxxx']
]
@@ -581,19 +610,29 @@ def test_bad_headers(self):
The various kinds of bad headers are more fully tested in test_utils
"""
self.normal_message._headers = [ # pylint: disable=protected-access
msg = copy.copy(self.normal_message)
# assume we've already deserialized the data
msg.set_value(self.normal_event_data)
msg._headers = [ # pylint: disable=protected-access
('ce_type', b'org.openedx.learning.auth.session.login.completed.v1'),
('ce_id', b'bad_id')
]
with pytest.raises(UnusableMessageError) as excinfo:
self.event_consumer.emit_signals_from_message(self.normal_message)
self.event_consumer.emit_signals_from_message(msg)

assert excinfo.value.args == (
"Error determining metadata from message headers: badly formed hexadecimal UUID string",
)

assert not self.mock_receiver.called

def test_no_deserializer_if_no_registry_client(self):
with pytest.raises(Exception) as excinfo:
get_deserializer(self.signal, None)
assert excinfo.value.args == (
"Cannot create Kafka deserializer -- missing library or settings",
)


class TestCommand(TestCase):
"""
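The recurring `with patch(...)` blocks above share one idea: because deserialization now happens inside the consume loop itself, the tests swap AvroDeserializer for a factory whose product ignores its (bytes, context) arguments and returns pre-built event data. A condensed sketch of that pattern (a hypothetical test body; event_consumer, normal_message, and normal_event_data stand in for the fixtures built in setUp):

```python
from unittest.mock import Mock, patch

from edx_event_bus_kafka.internal.tests.test_utils import side_effects

# Shut the loop down after one message so consume_indefinitely() returns.
with patch.object(event_consumer, 'emit_signals_from_message',
                  side_effect=side_effects([event_consumer._shut_down])):
    # Bypass real Avro decoding: the patched deserializer discards the raw
    # bytes and context and hands back ready-made event data.
    with patch('edx_event_bus_kafka.internal.consumer.AvroDeserializer',
               return_value=lambda _data, _ctx: normal_event_data):
        event_consumer.consumer = Mock(**{'poll.return_value': normal_message})
        event_consumer.consume_indefinitely()  # exits after one message
```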
5 changes: 4 additions & 1 deletion edx_event_bus_kafka/internal/tests/test_utils.py
@@ -92,9 +92,12 @@ def key(self) -> Optional[bytes]:
return self._key

def value(self):
"""Deserialized event value."""
"""Event value (bytes or object)"""
return self._value

def set_value(self, value):
self._value = value

def error(self):
return self._error

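A brief sketch of how the new set_value hook gets used in tests: build a FakeMessage carrying raw bytes, then overwrite its value with already-deserialized data before exercising code that runs after the deserialization step. The constructor arguments mirror setUp() above; this assumes any omitted arguments (topic, partition, offset) are optional:

```python
from confluent_kafka import TIMESTAMP_CREATE_TIME

from edx_event_bus_kafka.internal.tests.test_utils import FakeMessage

msg = FakeMessage(
    headers=[('ce_type', b'org.openedx.learning.auth.session.login.completed.v1')],
    key=b'\x00\x00\x00\x00\x01\x0cfoobob',
    value=b'raw-avro-bytes',  # placeholder for wire-format Avro
    error=None,
    timestamp=(TIMESTAMP_CREATE_TIME, 1675114920123),
)
msg.set_value({'user': 'deserialized-user-data'})  # simulate the explicit step
assert msg.value() == {'user': 'deserialized-user-data'}
```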
