diff --git a/ddapm_test_agent/client.py b/ddapm_test_agent/client.py new file mode 100644 index 00000000..2f749d14 --- /dev/null +++ b/ddapm_test_agent/client.py @@ -0,0 +1,138 @@ +import time +from typing import Any +from typing import List +import urllib.parse + +import requests + +from ddapm_test_agent.trace import Trace + + +class TestAgentClient: + __test__ = False + + def __init__(self, base_url: str): + self._base_url = base_url + self._session = requests.Session() + + def _url(self, path: str) -> str: + return urllib.parse.urljoin(self._base_url, path) + + def traces(self, clear=False, **kwargs: Any) -> List[Trace]: + resp = self._session.get(self._url("/test/session/traces"), **kwargs) + if clear: + self.clear() + json = resp.json() + return json + + def requests(self, **kwargs: Any) -> List[Any]: + resp = self._session.get(self._url("/test/session/requests"), **kwargs) + json = resp.json() + return json + + def raw_telemetry(self, clear=False) -> List[Any]: + raw_reqs = self.requests() + reqs = [] + for req in raw_reqs: + if req["url"].endswith("/telemetry/proxy/api/v2/apmtelemetry"): + reqs.append(req) + if clear: + self.clear() + return reqs + + def telemetry(self, clear=False, **kwargs): + resp = self._session.get(self._url("/test/session/apmtelemetry"), **kwargs) + if clear: + self.clear() + return resp.json() + + def clear(self, **kwargs) -> None: + self._session.get(self._url("/test/session/clear"), **kwargs) + + def info(self, **kwargs): + resp = self._session.get(self._url("/info"), **kwargs) + json = resp.json() + return json + + def wait_for_num_traces(self, num: int, clear: bool = False, wait_loops: int = 30) -> List[Trace]: + """Wait for `num` traces to be received from the test agent. + + Returns after the number of traces has been received or raises otherwise after 2 seconds of polling. + + Returned traces are sorted by the first span start time to simplify assertions for more than one trace by knowing that returned traces are in the same order as they have been created. + """ + num_received = 0 + traces = [] + for i in range(wait_loops): + try: + traces = self.traces(clear=False) + except requests.exceptions.RequestException: + pass + else: + num_received = len(traces) + if num_received == num: + if clear: + self.clear() + return sorted(traces, key=lambda trace: trace[0]["start"]) + time.sleep(0.1) + raise ValueError( + "Number (%r) of traces not available from test agent, got %r:\n%r" % (num, num_received, traces) + ) + + def wait_for_num_spans(self, num: int, clear: bool = False, wait_loops: int = 30) -> List[Trace]: + """Wait for `num` spans to be received from the test agent. + + Returns after the number of spans has been received or raises otherwise after 2 seconds of polling. + + Returned traces are sorted by the first span start time to simplify assertions for more than one trace by knowing that returned traces are in the same order as they have been created. + """ + num_received = None + for i in range(wait_loops): + try: + traces = self.traces(clear=False) + except requests.exceptions.RequestException: + pass + else: + num_received = 0 + for trace in traces: + num_received += len(trace) + if num_received == num: + if clear: + self.clear() + return sorted(traces, key=lambda trace: trace[0]["start"]) + time.sleep(0.1) + raise ValueError("Number (%r) of spans not available from test agent, got %r" % (num, num_received)) + + def wait_for_telemetry_event(self, event_name: str, clear: bool = False, wait_loops: int = 200) -> None: + """Wait for and return the given telemetry event from the test agent.""" + for i in range(wait_loops): + try: + events = self.telemetry(clear=False) + except requests.exceptions.RequestException: + pass + else: + for event in events: + if event["request_type"] == "message-batch": + for message in event["payload"]: + if message["request_type"] == event_name: + if clear: + self.clear() + return message + elif event["request_type"] == event_name: + if clear: + self.clear() + return event + time.sleep(0.01) + raise AssertionError("Telemetry event %r not found" % event_name) + + def wait_to_start(self, num_tries: int = 50, delay: float = 0.1) -> None: + exc = [] + for i in range(num_tries): + try: + self.info() + except requests.exceptions.RequestException as e: + exc.append(e) + time.sleep(delay) + else: + return + raise AssertionError(f"Test agent did not start in time ({num_tries * delay} seconds). Got {exc[-1]}") diff --git a/releasenotes/notes/client-173dbc6655e42337.yaml b/releasenotes/notes/client-173dbc6655e42337.yaml new file mode 100644 index 00000000..033ed1f4 --- /dev/null +++ b/releasenotes/notes/client-173dbc6655e42337.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Add a `client` module which provides an API client that can be used to + programmatically interface with the test agent. diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..a3e9bee8 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,32 @@ +from ddtrace import Tracer +import pytest + +from ddapm_test_agent.client import TestAgentClient + + +@pytest.fixture +async def client(testagent, testagent_url): + return TestAgentClient(testagent_url) + + +@pytest.fixture +async def tracer(testagent_url): + t = Tracer(testagent_url) + yield t + t.shutdown() + + +async def test_client_traces(client: TestAgentClient, tracer: Tracer) -> None: + assert client.traces() == [] + with tracer.trace("test"): + pass + traces = client.wait_for_num_traces(1) + assert len(traces) == 1 + + +async def test_client_requests(client: TestAgentClient, tracer: Tracer) -> None: + assert client.requests() == [] + with tracer.trace("test"): + pass + tracer.flush() + assert len(client.requests()) == 1