diff --git a/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py b/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py index 7a0913f973b..abca2e052bc 100644 --- a/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py +++ b/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py @@ -463,9 +463,6 @@ def test_span_on_error(self): # Verify exception details have been added to span. self.assertEqual(scope.span.unwrap().attributes["error"], True) - self.assertEqual( - scope.span.unwrap().events[0].attributes["error.kind"], Exception - ) def test_inject_http_headers(self): """Test `inject()` method for Format.HTTP_HEADERS.""" diff --git a/opentelemetry-sdk/CHANGELOG.md b/opentelemetry-sdk/CHANGELOG.md index f771a2df346..476908cf91d 100644 --- a/opentelemetry-sdk/CHANGELOG.md +++ b/opentelemetry-sdk/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Validate span attribute types in SDK (#678) + ## 0.7b1 Released 2020-05-12 diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index 5eff5e61307..5b74b4a6186 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -17,7 +17,6 @@ import atexit import json import logging -import os import random import threading from collections import OrderedDict @@ -41,6 +40,7 @@ MAX_NUM_ATTRIBUTES = 32 MAX_NUM_EVENTS = 128 MAX_NUM_LINKS = 32 +VALID_ATTR_VALUE_TYPES = (bool, str, int, float) class SpanProcessor: @@ -189,6 +189,48 @@ def attributes(self) -> types.Attributes: return self._event_formatter() +def _is_valid_attribute_value(value: types.AttributeValue) -> bool: + """Checks if attribute value is valid. + + An attribute value is valid if it is one of the valid types. If the value + is a sequence, it is only valid if all items in the sequence are of valid + type, not a sequence, and are of the same type. 
+ """ + + if isinstance(value, Sequence): + if len(value) == 0: + return True + + first_element_type = type(value[0]) + + if first_element_type not in VALID_ATTR_VALUE_TYPES: + logger.warning( + "Invalid type %s in attribute value sequence. Expected one of " + "%s or a sequence of those types", + first_element_type.__name__, + [valid_type.__name__ for valid_type in VALID_ATTR_VALUE_TYPES], + ) + return False + + for element in list(value)[1:]: + if not isinstance(element, first_element_type): + logger.warning( + "Mixed types %s and %s in attribute value sequence", + first_element_type.__name__, + type(element).__name__, + ) + return False + elif not isinstance(value, VALID_ATTR_VALUE_TYPES): + logger.warning( + "Invalid type %s for attribute value. Expected one of %s or a " + "sequence of those types", + type(value).__name__, + [valid_type.__name__ for valid_type in VALID_ATTR_VALUE_TYPES], + ) + return False + return True + + class Span(trace_api.Span): """See `opentelemetry.trace.Span`. 
@staticmethod
def _filter_attribute_values(attributes: types.Attributes):
    """Sanitize an attribute mapping in place.

    Every entry whose value is rejected by ``_is_valid_attribute_value`` is
    removed from ``attributes``. Every accepted value that is a mutable
    sequence is replaced by a tuple holding the same items. A ``None`` or
    empty mapping is left untouched.

    Args:
        attributes: The mapping to clean up; it is modified directly.
    """
    if not attributes:
        return

    # Iterate over a snapshot so entries can be removed while looping.
    for name, candidate in list(attributes.items()):
        if not _is_valid_attribute_value(candidate):
            attributes.pop(name)
            continue

        if isinstance(candidate, MutableSequence):
            frozen = tuple(candidate)
        else:
            frozen = candidate
        attributes[name] = frozen
"fw"]) + self.assertFalse( + trace._is_valid_attribute_value(["sw", "lf", 3.4, "ss"]) ) + self.assertFalse(trace._is_valid_attribute_value([1, 2, 3.4, 5])) + self.assertTrue(trace._is_valid_attribute_value([1, 2, 3, 5])) + self.assertTrue(trace._is_valid_attribute_value([1.2, 2.3, 3.4, 4.5])) + self.assertTrue(trace._is_valid_attribute_value([True, False])) + self.assertTrue(trace._is_valid_attribute_value(["ss", "dw", "fw"])) + self.assertTrue(trace._is_valid_attribute_value([])) + self.assertFalse(trace._is_valid_attribute_value(dict())) + self.assertTrue(trace._is_valid_attribute_value(True)) + self.assertTrue(trace._is_valid_attribute_value("hi")) + self.assertTrue(trace._is_valid_attribute_value(3.4)) + self.assertTrue(trace._is_valid_attribute_value(15)) def test_sampling_attributes(self): decision_attributes = { @@ -561,33 +549,67 @@ def test_events(self): # event name and attributes now = time_ns() - root.add_event("event1", {"name": "pluto"}) + root.add_event( + "event1", {"name": "pluto", "some_bools": [True, False]} + ) # event name, attributes and timestamp now = time_ns() - root.add_event("event2", {"name": "birthday"}, now) + root.add_event("event2", {"name": ["birthday"]}, now) + + mutable_list = ["original_contents"] + root.add_event("event3", {"name": mutable_list}) def event_formatter(): return {"name": "hello"} # lazy event - root.add_lazy_event("event3", event_formatter, now) + root.add_lazy_event("event4", event_formatter, now) - self.assertEqual(len(root.events), 4) + self.assertEqual(len(root.events), 5) self.assertEqual(root.events[0].name, "event0") self.assertEqual(root.events[0].attributes, {}) self.assertEqual(root.events[1].name, "event1") - self.assertEqual(root.events[1].attributes, {"name": "pluto"}) + self.assertEqual( + root.events[1].attributes, + {"name": "pluto", "some_bools": (True, False)}, + ) self.assertEqual(root.events[2].name, "event2") - self.assertEqual(root.events[2].attributes, {"name": "birthday"}) + 
self.assertEqual( + root.events[2].attributes, {"name": ("birthday",)} + ) self.assertEqual(root.events[2].timestamp, now) self.assertEqual(root.events[3].name, "event3") - self.assertEqual(root.events[3].attributes, {"name": "hello"}) - self.assertEqual(root.events[3].timestamp, now) + self.assertEqual( + root.events[3].attributes, {"name": ("original_contents",)} + ) + mutable_list = ["new_contents"] + self.assertEqual( + root.events[3].attributes, {"name": ("original_contents",)} + ) + + self.assertEqual(root.events[4].name, "event4") + self.assertEqual(root.events[4].attributes, {"name": "hello"}) + self.assertEqual(root.events[4].timestamp, now) + + def test_invalid_event_attributes(self): + self.assertIsNone(self.tracer.get_current_span()) + + with self.tracer.start_as_current_span("root") as root: + root.add_event("event0", {"attr1": True, "attr2": ["hi", False]}) + root.add_event("event0", {"attr1": dict()}) + root.add_event("event0", {"attr1": [[True]]}) + root.add_event("event0", {"attr1": [dict()], "attr2": [1, 2]}) + + self.assertEqual(len(root.events), 4) + self.assertEqual(root.events[0].attributes, {"attr1": True}) + self.assertEqual(root.events[1].attributes, {}) + self.assertEqual(root.events[2].attributes, {}) + self.assertEqual(root.events[3].attributes, {"attr2": (1, 2)}) def test_links(self): other_context1 = trace_api.SpanContext(