This repository has been archived by the owner on Dec 10, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #96 from maralla/tracking
add tracking support
- Loading branch information
Showing
10 changed files
with
545 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
include README.rst CHANGES.rst | ||
include thriftpy/protocol/cybin/*.pyx thriftpy/protocol/cybin/*.c thriftpy/protocol/cybin/*.h | ||
include thriftpy/transport/*/*.pyx thriftpy/transport/*/*.c | ||
include thriftpy/transport/*.pyx thriftpy/transport/*.pxd thriftpy/transport/*.c | ||
recursive-include thriftpy/protocol/cybin *.pyx *.c *.h | ||
recursive-include thriftpy/transport *.pyx *.pxd *.c | ||
include thriftpy/contrib/tracking/tracking.thrift |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import | ||
|
||
import contextlib | ||
import os | ||
import multiprocessing | ||
import time | ||
import tempfile | ||
import pickle | ||
import thriftpy | ||
|
||
try: | ||
import dbm | ||
except ImportError: | ||
import dbm.ndbm as dbm | ||
|
||
import pytest | ||
|
||
from thriftpy.contrib.tracking import TTrackedProcessor, TTrackedClient, \ | ||
TrackerBase, trace_thrift | ||
from thriftpy.contrib.tracking.tracker import ctx | ||
|
||
from thriftpy.thrift import TProcessorFactory, TClient, TProcessor | ||
from thriftpy.server import TThreadedServer | ||
from thriftpy.transport import TServerSocket, TBufferedTransportFactory, \ | ||
TTransportException, TSocket | ||
from thriftpy.protocol import TBinaryProtocolFactory | ||
|
||
|
||
addressbook = thriftpy.load(os.path.join(os.path.dirname(__file__), | ||
"addressbook.thrift")) | ||
_, db_file = tempfile.mkstemp() | ||
|
||
|
||
class SampleTracker(TrackerBase): | ||
def record(self, header, exception): | ||
db = dbm.open(db_file, 'w') | ||
|
||
key = "%s:%d" % (header.request_id, header.seq) | ||
db[key.encode("ascii")] = pickle.dumps(header.__dict__) | ||
db.close() | ||
|
||
tracker = SampleTracker("test_client", "test_server") | ||
|
||
|
||
class Dispatcher(object): | ||
def __init__(self): | ||
self.ab = addressbook.AddressBook() | ||
self.ab.people = {} | ||
|
||
def ping(self): | ||
return True | ||
|
||
def hello(self, name): | ||
return "hello %s" % name | ||
|
||
def remove(self, name): | ||
person = addressbook.Person(name="mary") | ||
with client(port=6098) as c: | ||
c.add(person) | ||
return True | ||
|
||
def get_phonenumbers(self, name, count): | ||
return [addressbook.PhoneNumber(number="sdaf"), | ||
addressbook.PhoneNumber(number='saf')] | ||
|
||
def add(self, person): | ||
with client(port=6099) as c: | ||
c.hello("jane") | ||
return True | ||
|
||
def get(self, name): | ||
raise addressbook.PersonNotExistsError() | ||
|
||
|
||
class TSampleServer(TThreadedServer): | ||
def __init__(self, processor_factory, trans, trans_factory, prot_factory): | ||
self.daemon = False | ||
self.processor_factory = processor_factory | ||
self.trans = trans | ||
|
||
self.itrans_factory = self.otrans_factory = trans_factory | ||
self.iprot_factory = self.oprot_factory = prot_factory | ||
self.closed = False | ||
|
||
def handle(self, client): | ||
processor = self.processor_factory.get_processor() | ||
itrans = self.itrans_factory.get_transport(client) | ||
otrans = self.otrans_factory.get_transport(client) | ||
iprot = self.iprot_factory.get_protocol(itrans) | ||
oprot = self.oprot_factory.get_protocol(otrans) | ||
try: | ||
while True: | ||
processor.process(iprot, oprot) | ||
except TTransportException: | ||
pass | ||
except Exception: | ||
raise | ||
|
||
itrans.close() | ||
otrans.close() | ||
|
||
|
||
def gen_server(port=6029, tracker=tracker, processor=TTrackedProcessor): | ||
args = [processor, addressbook.AddressBookService, Dispatcher()] | ||
if tracker: | ||
args.insert(1, tracker) | ||
processor = TProcessorFactory(*args) | ||
server_socket = TServerSocket(host="localhost", port=port) | ||
server = TSampleServer(processor, server_socket, | ||
prot_factory=TBinaryProtocolFactory(), | ||
trans_factory=TBufferedTransportFactory()) | ||
ps = multiprocessing.Process(target=server.serve) | ||
ps.start() | ||
return ps, server | ||
|
||
|
||
@pytest.fixture | ||
def server(request): | ||
ps, ser = gen_server() | ||
time.sleep(0.15) | ||
|
||
def fin(): | ||
if ps.is_alive(): | ||
ps.terminate() | ||
request.addfinalizer(fin) | ||
return ser | ||
|
||
|
||
@pytest.fixture | ||
def server1(request): | ||
ps, ser = gen_server(port=6098) | ||
time.sleep(0.15) | ||
|
||
def fin(): | ||
if ps.is_alive(): | ||
ps.terminate() | ||
request.addfinalizer(fin) | ||
return ser | ||
|
||
|
||
@pytest.fixture | ||
def server2(request): | ||
ps, ser = gen_server(port=6099) | ||
time.sleep(0.15) | ||
|
||
def fin(): | ||
if ps.is_alive(): | ||
ps.terminate() | ||
request.addfinalizer(fin) | ||
return ser | ||
|
||
|
||
@pytest.fixture | ||
def not_tracked_server(request): | ||
ps, ser = gen_server(port=6030, tracker=None, processor=TProcessor) | ||
time.sleep(0.15) | ||
|
||
def fin(): | ||
if ps.is_alive(): | ||
ps.terminate() | ||
request.addfinalizer(fin) | ||
return ser | ||
|
||
|
||
@contextlib.contextmanager | ||
def client(client_class=TTrackedClient, port=6029): | ||
socket = TSocket("localhost", port) | ||
|
||
try: | ||
trans = TBufferedTransportFactory().get_transport(socket) | ||
proto = TBinaryProtocolFactory().get_protocol(trans) | ||
trans.open() | ||
args = [addressbook.AddressBookService, proto] | ||
if client_class.__name__ == TTrackedClient.__name__: | ||
args.insert(0, tracker) | ||
yield client_class(*args) | ||
finally: | ||
trans.close() | ||
|
||
|
||
@pytest.fixture | ||
def dbm_db(request): | ||
db = dbm.open(db_file, 'n') | ||
db.close() | ||
|
||
def fin(): | ||
try: | ||
os.remove(db_file) | ||
except OSError: | ||
pass | ||
request.addfinalizer(fin) | ||
|
||
|
||
def test_negotiation(server): | ||
with client() as c: | ||
assert c._upgraded is True | ||
|
||
|
||
def test_tracker(server, dbm_db): | ||
with client() as c: | ||
c.ping() | ||
|
||
time.sleep(0.2) | ||
|
||
db = dbm.open(db_file, 'r') | ||
headers = list(db.keys()) | ||
assert len(headers) == 1 | ||
|
||
request_id = headers[0] | ||
data = pickle.loads(db[request_id]) | ||
|
||
assert "start" in data and "end" in data | ||
data.pop("start") | ||
data.pop("end") | ||
|
||
assert data == { | ||
"request_id": request_id.decode("ascii").split(':')[0], | ||
"seq": 0, | ||
"client": "test_client", | ||
"server": "test_server", | ||
"api": "ping", | ||
"status": True | ||
} | ||
|
||
|
||
def test_tracker_chain(server, server1, server2, dbm_db): | ||
with client() as c: | ||
c.remove("jane") | ||
|
||
time.sleep(0.2) | ||
|
||
db = dbm.open(db_file, 'r') | ||
headers = list(db.keys()) | ||
assert len(headers) == 3 | ||
|
||
headers.sort() | ||
|
||
header0 = pickle.loads(db[headers[0]]) | ||
header1 = pickle.loads(db[headers[1]]) | ||
header2 = pickle.loads(db[headers[2]]) | ||
|
||
assert header0["request_id"] == header1["request_id"] == \ | ||
header2["request_id"] == headers[0].decode("ascii").split(':')[0] | ||
assert header0["seq"] == 0 and header1["seq"] == 1 and header2["seq"] == 2 | ||
|
||
|
||
def test_exception(server, dbm_db): | ||
with pytest.raises(addressbook.PersonNotExistsError): | ||
with client() as c: | ||
c.get("jane") | ||
|
||
db = dbm.open(db_file, 'r') | ||
headers = list(db.keys()) | ||
assert len(headers) == 1 | ||
|
||
header = pickle.loads(db[headers[0]]) | ||
assert header["status"] is False | ||
|
||
|
||
def test_not_tracked_client_tracked_server(server): | ||
with client(TClient) as c: | ||
c.ping() | ||
c.hello("world") | ||
|
||
|
||
def test_tracked_client_not_tracked_server(not_tracked_server): | ||
with client(port=6030) as c: | ||
assert c._upgraded is False | ||
c.ping() | ||
c.hello("cat") | ||
a = c.get_phonenumbers("hello", 54) | ||
assert len(a) == 2 | ||
assert a[0].number == 'sdaf' and a[1].number == 'saf' | ||
|
||
|
||
def test_request_id_func(): | ||
ctx.__dict__.clear() | ||
|
||
header = trace_thrift.RequestHeader() | ||
header.request_id = "hello" | ||
header.seq = 0 | ||
|
||
tracker = TrackerBase() | ||
tracker.handle(header) | ||
|
||
header2 = trace_thrift.RequestHeader() | ||
tracker.gen_header(header2) | ||
assert header2.request_id == "hello" |
Empty file.
Oops, something went wrong.