diff --git a/docs/EnvVars.md b/docs/EnvVars.md index ddaa2403..46754332 100644 --- a/docs/EnvVars.md +++ b/docs/EnvVars.md @@ -26,3 +26,4 @@ Environment Variable | Description | Default | `SW_KAFKA_REPORTER_TOPIC_MANAGEMENT` | Specifying Kafka topic name for service instance reporting and registering. | `skywalking-managements` | | `SW_KAFKA_REPORTER_TOPIC_SEGMENT` | Specifying Kafka topic name for Tracing data. | `skywalking-segments` | | `SW_KAFKA_REPORTER_CONFIG_key` | The configs to init KafkaProducer. it support the basic arguments (whose type is either `str`, `bool`, or `int`) listed [here](https://kafka-python.readthedocs.io/en/master/apidoc/KafkaProducer.html#kafka.KafkaProducer) | unset | +| `SW_CELERY_PARAMETERS_LENGTH`| The maximum length of `celery` functions parameters, longer than this will be truncated, 0 turns off | `512` | diff --git a/docs/Plugins.md b/docs/Plugins.md index 28688e04..23f74734 100644 --- a/docs/Plugins.md +++ b/docs/Plugins.md @@ -3,7 +3,7 @@ Library | Versions | Plugin Name | :--- | :--- | :--- | | [http.server](https://docs.python.org/3/library/http.server.html) | Python 3.5 ~ 3.9 | `sw_http_server` | -| [urllib.request](https://docs.python.org/3/library/urllib.request.html) | Python 3.5 ~ 3.8 | `sw_urllib_request` | +| [urllib.request](https://docs.python.org/3/library/urllib.request.html) | Python 3.5 ~ 3.9 | `sw_urllib_request` | | [requests](https://requests.readthedocs.io/en/master/) | >= 2.9.0 < 2.15.0, >= 2.17.0 <= 2.24.0 | `sw_requests` | | [Flask](https://flask.palletsprojects.com/en/1.1.x/) | >=1.0.4 <= 1.1.2 | `sw_flask` | | [PyMySQL](https://pymysql.readthedocs.io/en/latest/) | 0.10.0 | `sw_pymysql` | @@ -18,6 +18,9 @@ Library | Versions | Plugin Name | [sanic](https://sanic.readthedocs.io/en/latest/) | >= 20.3.0 <= 20.9.1 | `sw_sanic` | | [aiohttp](https://sanic.readthedocs.io/en/latest/) | >= 3.7.3 | `sw_aiohttp` | | [pyramid](https://trypyramid.com) | >= 1.9 | `sw_pyramid` | -| [psycopg2](https://www.psycopg.org/) | 2.8.6 | `sw_psycopg2` | +| [psycopg2](https://www.psycopg.org/) | >= 2.8.6 | `sw_psycopg2` | +| [celery](https://docs.celeryproject.org/) | >= 4.2.1 | `sw_celery` | + +* Note: The celery server running with "celery -A ..." should be run with the http protocol as it uses multiprocessing by default which is not compatible with the grpc protocol implementation in skywalking currently. Celery clients can use whatever protocol they want. The column `Versions` only indicates that the versions are tested, if you found the newer versions are also supported, welcome to add the newer version into the table. diff --git a/requirements.txt b/requirements.txt index 91a06d2e..71d5e803 100755 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ aiofiles==0.6.0 aiohttp==3.7.3 attrs==19.3.0 blindspin==2.0.1 +celery==4.4.7 certifi==2020.6.20 chardet==3.0.4 click==7.1.2 diff --git a/setup.py b/setup.py index ec24ee26..779baf12 100644 --- a/setup.py +++ b/setup.py @@ -33,12 +33,13 @@ author="Apache", author_email="dev@skywalking.apache.org", license="Apache 2.0", - packages=find_packages(exclude=("tests",)), + packages=find_packages(exclude=("tests", "tests.*")), include_package_data=True, install_requires=[ "grpcio", "grpcio-tools", "packaging", + "requests", "wrapt", ], extras_require={ diff --git a/skywalking/__init__.py b/skywalking/__init__.py index 75f9db2a..95461ceb 100644 --- a/skywalking/__init__.py +++ b/skywalking/__init__.py @@ -42,6 +42,7 @@ class Component(Enum): AioHttp = 7008 Pyramid = 7009 Psycopg = 7010 + Celery = 7011 class Layer(Enum): diff --git a/skywalking/agent/__init__.py b/skywalking/agent/__init__.py index 356a28a3..64ca4c1f 100644 --- a/skywalking/agent/__init__.py +++ b/skywalking/agent/__init__.py @@ -16,6 +16,7 @@ # import atexit +import os from queue import Queue, Full from threading import Thread, Event from typing import TYPE_CHECKING @@ -28,6 +29,11 @@ from skywalking.trace.context import Segment +__started = False +__protocol = Protocol() # type: Protocol +__heartbeat_thread = __report_thread = __queue = __finished = None + + def __heartbeat(): while not __finished.is_set(): if connected(): @@ -39,21 +45,26 @@ def __heartbeat(): def __report(): while not __finished.is_set(): if connected(): - __protocol.report(__queue) # is blocking actually + __protocol.report(__queue) # is blocking actually, blocks for max config.QUEUE_TIMEOUT seconds __finished.wait(1) -__heartbeat_thread = Thread(name='HeartbeatThread', target=__heartbeat, daemon=True) -__report_thread = Thread(name='ReportThread', target=__report, daemon=True) -__queue = Queue(maxsize=10000) -__finished = Event() -__protocol = Protocol() # type: Protocol -__started = False +def __init_threading(): + global __heartbeat_thread, __report_thread, __queue, __finished + + __queue = Queue(maxsize=10000) + __finished = Event() + __heartbeat_thread = Thread(name='HeartbeatThread', target=__heartbeat, daemon=True) + __report_thread = Thread(name='ReportThread', target=__report, daemon=True) + + __heartbeat_thread.start() + __report_thread.start() def __init(): global __protocol + if config.protocol == 'grpc': from skywalking.agent.protocol.grpc import GrpcProtocol __protocol = GrpcProtocol() @@ -65,14 +76,40 @@ def __init(): __protocol = KafkaProtocol() plugins.install() + __init_threading() def __fini(): __protocol.report(__queue, False) __queue.join() + __finished.set() + + +def __fork_before(): + if config.protocol != 'http': + logger.warning('fork() not currently supported with %s protocol' % config.protocol) + + # TODO: handle __queue and __finished correctly (locks, mutexes, etc...), need to lock before fork and unlock after + # if possible, or ensure they are not locked in threads (end threads and restart after fork?) + + __protocol.fork_before() + + +def __fork_after_in_parent(): + __protocol.fork_after_in_parent() + + +def __fork_after_in_child(): + __protocol.fork_after_in_child() + __init_threading() def start(): + global __started + if __started: + return + __started = True + flag = False try: from gevent import monkey @@ -82,22 +119,22 @@ def start(): if flag: import grpc.experimental.gevent as grpc_gevent grpc_gevent.init_gevent() - global __started - if __started: - raise RuntimeError('the agent can only be started once') + loggings.init() config.finalize() - __started = True + __init() - __heartbeat_thread.start() - __report_thread.start() + atexit.register(__fini) + if (hasattr(os, 'register_at_fork')): + os.register_at_fork(before=__fork_before, after_in_parent=__fork_after_in_parent, + after_in_child=__fork_after_in_child) + def stop(): atexit.unregister(__fini) __fini() - __finished.set() def started(): diff --git a/skywalking/agent/protocol/__init__.py b/skywalking/agent/protocol/__init__.py index 0f6e62e5..3202734c 100644 --- a/skywalking/agent/protocol/__init__.py +++ b/skywalking/agent/protocol/__init__.py @@ -20,8 +20,17 @@ class Protocol(ABC): + def fork_before(self): + pass + + def fork_after_in_parent(self): + pass + + def fork_after_in_child(self): + pass + def connected(self): - raise NotImplementedError() + return False def heartbeat(self): raise NotImplementedError() diff --git a/skywalking/agent/protocol/http.py b/skywalking/agent/protocol/http.py index 89d43bfb..809d1f8a 100644 --- a/skywalking/agent/protocol/http.py +++ b/skywalking/agent/protocol/http.py @@ -17,7 +17,9 @@ from skywalking.loggings import logger from queue import Queue, Empty +from time import time +from skywalking import config from skywalking.agent import Protocol from skywalking.client.http import HttpServiceManagementClient, HttpTraceSegmentReportService from skywalking.trace.segment import Segment @@ -29,20 +31,27 @@ def __init__(self): self.service_management = HttpServiceManagementClient() self.traces_reporter = HttpTraceSegmentReportService() + def fork_after_in_child(self): + self.service_management.fork_after_in_child() + self.traces_reporter.fork_after_in_child() + + def connected(self): + return True + def heartbeat(self): if not self.properties_sent: self.service_management.send_instance_props() self.properties_sent = True self.service_management.send_heart_beat() - def connected(self): - return True - def report(self, queue: Queue, block: bool = True): + start = time() + def generator(): while True: try: - segment = queue.get(block=block) # type: Segment + timeout = max(0, config.QUEUE_TIMEOUT - int(time() - start)) # type: int + segment = queue.get(block=block, timeout=timeout) # type: Segment except Empty: return @@ -52,4 +61,7 @@ def generator(): queue.task_done() - self.traces_reporter.report(generator=generator()) + try: + self.traces_reporter.report(generator=generator()) + except Exception: + pass diff --git a/skywalking/client/http.py b/skywalking/client/http.py index 87c1c082..a614a006 100644 --- a/skywalking/client/http.py +++ b/skywalking/client/http.py @@ -25,10 +25,14 @@ class HttpServiceManagementClient(ServiceManagementClient): def __init__(self): - self.session = requests.session() + self.session = requests.Session() + + def fork_after_in_child(self): + self.session.close() + self.session = requests.Session() def send_instance_props(self): - url = config.collector_address.rstrip('/') + '/v3/management/reportProperties' + url = 'http://' + config.collector_address.rstrip('/') + '/v3/management/reportProperties' res = self.session.post(url, json={ 'service': config.service_name, 'serviceInstance': config.service_instance, @@ -44,7 +48,7 @@ def send_heart_beat(self): config.service_name, config.service_instance, ) - url = config.collector_address.rstrip('/') + '/v3/management/keepAlive' + url = 'http://' + config.collector_address.rstrip('/') + '/v3/management/keepAlive' res = self.session.post(url, json={ 'service': config.service_name, 'serviceInstance': config.service_instance, @@ -54,10 +58,14 @@ def send_heart_beat(self): class HttpTraceSegmentReportService(TraceSegmentReportService): def __init__(self): - self.session = requests.session() + self.session = requests.Session() + + def fork_after_in_child(self): + self.session.close() + self.session = requests.Session() def report(self, generator): - url = config.collector_address.rstrip('/') + '/v3/segment' + url = 'http://' + config.collector_address.rstrip('/') + '/v3/segment' for segment in generator: res = self.session.post(url, json={ 'traceId': str(segment.related_traces[0]), @@ -76,10 +84,10 @@ def report(self, generator): 'componentId': span.component.value, 'isError': span.error_occurred, 'logs': [{ - 'time': log.timestamp * 1000, + 'time': int(log.timestamp * 1000), 'data': [{ 'key': item.key, - 'value': item.val + 'value': item.val, } for item in log.items], } for log in span.logs], 'tags': [{ diff --git a/skywalking/config.py b/skywalking/config.py index b8a8027b..b1ff2e43 100644 --- a/skywalking/config.py +++ b/skywalking/config.py @@ -59,13 +59,14 @@ kafka_bootstrap_servers = os.getenv('SW_KAFKA_REPORTER_BOOTSTRAP_SERVERS') or "localhost:9092" # type: str kafka_topic_management = os.getenv('SW_KAFKA_REPORTER_TOPIC_MANAGEMENT') or "skywalking-managements" # type: str kafka_topic_segment = os.getenv('SW_KAFKA_REPORTER_TOPIC_SEGMENT') or "skywalking-segments" # type: str +celery_parameters_length = int(os.getenv('SW_CELERY_PARAMETERS_LENGTH') or '512') def init( service: str = None, instance: str = None, collector: str = None, - protocol_type: str = 'grpc', + protocol_type: str = None, token: str = None, ): global service_name diff --git a/skywalking/plugins/sw_celery.py b/skywalking/plugins/sw_celery.py new file mode 100644 index 00000000..a7cff41a --- /dev/null +++ b/skywalking/plugins/sw_celery.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from skywalking import Layer, Component, config +from skywalking.trace import tags +from skywalking.trace.carrier import Carrier +from skywalking.trace.context import get_context +from skywalking.trace.tags import Tag + + +def install(): + from urllib.parse import urlparse + from celery import Celery + + def send_task(self, name, args=None, kwargs=None, **options): + # NOTE: Lines commented out below left for documentation purposes if sometime in the future exchange / queue + # names are wanted. Currently these do not match between producer and consumer so would need some work. + + broker_url = self.conf['broker_url'] + # exchange = options['exchange'] + # queue = options['routing_key'] + # op = 'celery/{}/{}/{}'.format(exchange or '', queue or '', name) + op = 'celery/' + name + + if broker_url: + url = urlparse(broker_url) + peer = '{}:{}'.format(url.hostname, url.port) + else: + peer = '???' + + with get_context().new_exit_span(op=op, peer=peer) as span: + span.layer = Layer.MQ + span.component = Component.Celery + + span.tag(Tag(key=tags.MqBroker, val=broker_url)) + # span.tag(Tag(key=tags.MqTopic, val=exchange)) + # span.tag(Tag(key=tags.MqQueue, val=queue)) + + if config.celery_parameters_length: + params = '*{}, **{}'.format(args, kwargs)[:config.celery_parameters_length] + span.tag(Tag(key=tags.CeleryParameters, val=params)) + + options = {**options} + headers = options.get('headers') + headers = {**headers} if headers else {} + options['headers'] = headers + + for item in span.inject(): + headers[item.key] = item.val + + return _send_task(self, name, args, kwargs, **options) + + _send_task = Celery.send_task + Celery.send_task = send_task + + def task_from_fun(self, _fun, name=None, **options): + def fun(*args, **kwargs): + req = task.request_stack.top + # di = req.get('delivery_info') + # exchange = di and di.get('exchange') + # queue = di and di.get('routing_key') + # op = 'celery/{}/{}/{}'.format(exchange or '', queue or '', name) + op = 'celery/' + name + carrier = Carrier() + + for item in carrier: + val = req.get(item.key) + + if val: + item.val = val + + context = get_context() + + if req.get('sw8'): + span = context.new_entry_span(op=op, carrier=carrier) + span.peer = (req.get('hostname') or '???').split('@', 1)[-1] + else: + span = context.new_local_span(op=op) + + with span: + span.layer = Layer.MQ + span.component = Component.Celery + + span.tag(Tag(key=tags.MqBroker, val=task.app.conf['broker_url'])) + # span.tag(Tag(key=tags.MqTopic, val=exchange)) + # span.tag(Tag(key=tags.MqQueue, val=queue)) + + if config.celery_parameters_length: + params = '*{}, **{}'.format(args, kwargs)[:config.celery_parameters_length] + span.tag(Tag(key=tags.CeleryParameters, val=params)) + + return _fun(*args, **kwargs) + + name = name or self.gen_task_name(_fun.__name__, _fun.__module__) + task = _task_from_fun(self, fun, name, **options) + + return task + + _task_from_fun = Celery._task_from_fun + Celery._task_from_fun = task_from_fun diff --git a/skywalking/plugins/sw_sanic.py b/skywalking/plugins/sw_sanic.py index 1ea62990..2cf68bfb 100644 --- a/skywalking/plugins/sw_sanic.py +++ b/skywalking/plugins/sw_sanic.py @@ -27,7 +27,7 @@ version_rule = { "name": "sanic", - "rules": [">=20.3.0"] + "rules": [">=20.3.0", "<21.0.0"] } diff --git a/skywalking/trace/context.py b/skywalking/trace/context.py index 51dffa36..f7b74403 100644 --- a/skywalking/trace/context.py +++ b/skywalking/trace/context.py @@ -119,7 +119,7 @@ def new_exit_span(self, op: str, peer: str) -> Span: spans = _spans_dup() parent = spans[-1] if spans else None # type: Span - span = parent if parent is not None and parent.kind.is_exit else ExitSpan( + span = ExitSpan( context=self, sid=self._sid.next(), pid=parent.sid if parent else -1, diff --git a/skywalking/trace/tags.py b/skywalking/trace/tags.py index 57a389d2..f7c9abd9 100644 --- a/skywalking/trace/tags.py +++ b/skywalking/trace/tags.py @@ -31,3 +31,4 @@ MqBroker = 'mq.broker' MqTopic = 'mq.topic' MqQueue = 'mq.queue' +CeleryParameters = 'celery.parameters'