Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
Merge pull request #96 from maralla/tracking
Browse files Browse the repository at this point in the history
add tracking support
  • Loading branch information
lxyu committed Mar 11, 2015
2 parents d24912d + e14a63d commit 67acce2
Show file tree
Hide file tree
Showing 10 changed files with 545 additions and 13 deletions.
6 changes: 3 additions & 3 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include README.rst CHANGES.rst
include thriftpy/protocol/cybin/*.pyx thriftpy/protocol/cybin/*.c thriftpy/protocol/cybin/*.h
include thriftpy/transport/*/*.pyx thriftpy/transport/*/*.c
include thriftpy/transport/*.pyx thriftpy/transport/*.pxd thriftpy/transport/*.c
recursive-include thriftpy/protocol/cybin *.pyx *.c *.h
recursive-include thriftpy/transport *.pyx *.pxd *.c
include thriftpy/contrib/tracking/tracking.thrift
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
author="Lx Yu",
author_email="i@lxyu.net",
packages=find_packages(exclude=['benchmark', 'docs', 'tests']),
package_data={"thriftpy": ["contrib/tracking/tracking.thrift"]},
entry_points={},
url="https://thriftpy.readthedocs.org/",
license="MIT",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_protocol_cybinary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

class TItem(TPayload):
thrift_spec = {
1: (TType.I32, "id"),
2: (TType.LIST, "phones", (TType.STRING)),
1: (TType.I32, "id", False),
2: (TType.LIST, "phones", TType.STRING, False),
}
default_spec = [("id", None), ("phones", None)]

Expand Down Expand Up @@ -214,8 +214,8 @@ def test_read_huge_args():

class Hello(TPayload):
thrift_spec = {
1: (TType.STRING, "name"),
2: (TType.STRING, "world"),
1: (TType.STRING, "name", False),
2: (TType.STRING, "world", False),
}
default_spec = [("name", None), ("world", None)]

Expand Down Expand Up @@ -346,8 +346,8 @@ def test_read_wrong_arg_type():

class TWrongTypeItem(TPayload):
thrift_spec = {
1: (TType.STRING, "id"),
2: (TType.LIST, "phones", (TType.STRING)),
1: (TType.STRING, "id", False),
2: (TType.LIST, "phones", TType.STRING, False),
}
default_spec = [("id", None), ("phones", None)]

Expand Down
290 changes: 290 additions & 0 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import

import contextlib
import os
import multiprocessing
import time
import tempfile
import pickle
import thriftpy

try:
import dbm
except ImportError:
import dbm.ndbm as dbm

import pytest

from thriftpy.contrib.tracking import TTrackedProcessor, TTrackedClient, \
TrackerBase, trace_thrift
from thriftpy.contrib.tracking.tracker import ctx

from thriftpy.thrift import TProcessorFactory, TClient, TProcessor
from thriftpy.server import TThreadedServer
from thriftpy.transport import TServerSocket, TBufferedTransportFactory, \
TTransportException, TSocket
from thriftpy.protocol import TBinaryProtocolFactory


addressbook = thriftpy.load(os.path.join(os.path.dirname(__file__),
"addressbook.thrift"))
_, db_file = tempfile.mkstemp()


class SampleTracker(TrackerBase):
def record(self, header, exception):
db = dbm.open(db_file, 'w')

key = "%s:%d" % (header.request_id, header.seq)
db[key.encode("ascii")] = pickle.dumps(header.__dict__)
db.close()

tracker = SampleTracker("test_client", "test_server")


class Dispatcher(object):
def __init__(self):
self.ab = addressbook.AddressBook()
self.ab.people = {}

def ping(self):
return True

def hello(self, name):
return "hello %s" % name

def remove(self, name):
person = addressbook.Person(name="mary")
with client(port=6098) as c:
c.add(person)
return True

def get_phonenumbers(self, name, count):
return [addressbook.PhoneNumber(number="sdaf"),
addressbook.PhoneNumber(number='saf')]

def add(self, person):
with client(port=6099) as c:
c.hello("jane")
return True

def get(self, name):
raise addressbook.PersonNotExistsError()


class TSampleServer(TThreadedServer):
def __init__(self, processor_factory, trans, trans_factory, prot_factory):
self.daemon = False
self.processor_factory = processor_factory
self.trans = trans

self.itrans_factory = self.otrans_factory = trans_factory
self.iprot_factory = self.oprot_factory = prot_factory
self.closed = False

def handle(self, client):
processor = self.processor_factory.get_processor()
itrans = self.itrans_factory.get_transport(client)
otrans = self.otrans_factory.get_transport(client)
iprot = self.iprot_factory.get_protocol(itrans)
oprot = self.oprot_factory.get_protocol(otrans)
try:
while True:
processor.process(iprot, oprot)
except TTransportException:
pass
except Exception:
raise

itrans.close()
otrans.close()


def gen_server(port=6029, tracker=tracker, processor=TTrackedProcessor):
args = [processor, addressbook.AddressBookService, Dispatcher()]
if tracker:
args.insert(1, tracker)
processor = TProcessorFactory(*args)
server_socket = TServerSocket(host="localhost", port=port)
server = TSampleServer(processor, server_socket,
prot_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory())
ps = multiprocessing.Process(target=server.serve)
ps.start()
return ps, server


@pytest.fixture
def server(request):
ps, ser = gen_server()
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@pytest.fixture
def server1(request):
ps, ser = gen_server(port=6098)
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@pytest.fixture
def server2(request):
ps, ser = gen_server(port=6099)
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@pytest.fixture
def not_tracked_server(request):
ps, ser = gen_server(port=6030, tracker=None, processor=TProcessor)
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@contextlib.contextmanager
def client(client_class=TTrackedClient, port=6029):
socket = TSocket("localhost", port)

try:
trans = TBufferedTransportFactory().get_transport(socket)
proto = TBinaryProtocolFactory().get_protocol(trans)
trans.open()
args = [addressbook.AddressBookService, proto]
if client_class.__name__ == TTrackedClient.__name__:
args.insert(0, tracker)
yield client_class(*args)
finally:
trans.close()


@pytest.fixture
def dbm_db(request):
db = dbm.open(db_file, 'n')
db.close()

def fin():
try:
os.remove(db_file)
except OSError:
pass
request.addfinalizer(fin)


def test_negotiation(server):
with client() as c:
assert c._upgraded is True


def test_tracker(server, dbm_db):
with client() as c:
c.ping()

time.sleep(0.2)

db = dbm.open(db_file, 'r')
headers = list(db.keys())
assert len(headers) == 1

request_id = headers[0]
data = pickle.loads(db[request_id])

assert "start" in data and "end" in data
data.pop("start")
data.pop("end")

assert data == {
"request_id": request_id.decode("ascii").split(':')[0],
"seq": 0,
"client": "test_client",
"server": "test_server",
"api": "ping",
"status": True
}


def test_tracker_chain(server, server1, server2, dbm_db):
with client() as c:
c.remove("jane")

time.sleep(0.2)

db = dbm.open(db_file, 'r')
headers = list(db.keys())
assert len(headers) == 3

headers.sort()

header0 = pickle.loads(db[headers[0]])
header1 = pickle.loads(db[headers[1]])
header2 = pickle.loads(db[headers[2]])

assert header0["request_id"] == header1["request_id"] == \
header2["request_id"] == headers[0].decode("ascii").split(':')[0]
assert header0["seq"] == 0 and header1["seq"] == 1 and header2["seq"] == 2


def test_exception(server, dbm_db):
with pytest.raises(addressbook.PersonNotExistsError):
with client() as c:
c.get("jane")

db = dbm.open(db_file, 'r')
headers = list(db.keys())
assert len(headers) == 1

header = pickle.loads(db[headers[0]])
assert header["status"] is False


def test_not_tracked_client_tracked_server(server):
with client(TClient) as c:
c.ping()
c.hello("world")


def test_tracked_client_not_tracked_server(not_tracked_server):
with client(port=6030) as c:
assert c._upgraded is False
c.ping()
c.hello("cat")
a = c.get_phonenumbers("hello", 54)
assert len(a) == 2
assert a[0].number == 'sdaf' and a[1].number == 'saf'


def test_request_id_func():
ctx.__dict__.clear()

header = trace_thrift.RequestHeader()
header.request_id = "hello"
header.seq = 0

tracker = TrackerBase()
tracker.handle(header)

header2 = trace_thrift.RequestHeader()
tracker.gen_header(header2)
assert header2.request_id == "hello"
Empty file added thriftpy/contrib/__init__.py
Empty file.
Loading

0 comments on commit 67acce2

Please sign in to comment.