diff --git a/docs/faq.rst b/docs/faq.rst index 6bbc2f9a5..a7333fc71 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -545,7 +545,7 @@ To start hacking on Stem please do the following and don't hesitate to let me know if you get stuck or would like to discuss anything! #. Clone our `git `_ repository: **git clone https://git.torproject.org/stem.git** -#. Get our test dependencies: **sudo pip install mock pycodestyle pyflakes**. +#. Get our test dependencies: **sudo pip install mock pycodestyle pyflakes mypy**. #. Find a `bug or feature `_ that sounds interesting. #. When you have something that you would like to contribute back do the following... @@ -588,11 +588,14 @@ can exercise alternate Tor configurations with the ``--target`` argument (see ~/stem$ ./run_tests.py --integ --tor /path/to/tor ~/stem$ ./run_tests.py --integ --target RUN_COOKIE -**Static** tests use `pyflakes `_ to do static -error checking and `pycodestyle -`_ for style checking. If you -have them installed then they automatically take place as part of all test -runs. +**Static** tests use... + +* `pyflakes `_ for error checks +* `pycodestyle `_ for style checks +* `mypy `_ for type checks + +If you have them installed then they automatically take place as part of all +test runs. See ``run_tests.py --help`` for more usage information. 
diff --git a/run_tests.py b/run_tests.py index c9196e188..fd46211fc 100755 --- a/run_tests.py +++ b/run_tests.py @@ -194,7 +194,7 @@ def main(): test_config.load(os.environ['STEM_TEST_CONFIG']) try: - args = test.arguments.parse(sys.argv[1:]) + args = test.arguments.Arguments.parse(sys.argv[1:]) test.task.TOR_VERSION.args = (args.tor_path,) test.output.SUPPRESS_STDOUT = args.quiet except ValueError as exc: @@ -202,7 +202,7 @@ def main(): sys.exit(1) if args.print_help: - println(test.arguments.get_help()) + println(test.arguments.Arguments.get_help()) sys.exit() elif not args.run_unit and not args.run_integ: println('Nothing to run (for usage provide --help)\n') @@ -217,12 +217,14 @@ def main(): test.task.CRYPTO_VERSION, test.task.PYFLAKES_VERSION, test.task.PYCODESTYLE_VERSION, + test.task.MYPY_VERSION, test.task.CLEAN_PYC, test.task.UNUSED_TESTS, test.task.IMPORT_TESTS, test.task.REMOVE_TOR_DATA_DIR if args.run_integ else None, test.task.PYFLAKES_TASK if not args.specific_test else None, test.task.PYCODESTYLE_TASK if not args.specific_test else None, + test.task.MYPY_TASK if not args.specific_test else None, ) # Test logging. If '--log-file' is provided we log to that location, @@ -334,7 +336,7 @@ def main(): static_check_issues = {} - for task in (test.task.PYFLAKES_TASK, test.task.PYCODESTYLE_TASK): + for task in (test.task.PYFLAKES_TASK, test.task.PYCODESTYLE_TASK, test.task.MYPY_TASK): if not task.is_available and task.unavailable_msg: println(task.unavailable_msg, ERROR) else: @@ -381,7 +383,7 @@ def _print_static_issues(static_check_issues): if static_check_issues: println('STATIC CHECKS', STATUS) - for file_path in static_check_issues: + for file_path in sorted(static_check_issues): println('* %s' % file_path, STATUS) # Make a dict of line numbers to its issues. 
This is so we can both sort diff --git a/stem/__init__.py b/stem/__init__.py index 907156fe0..ce8d70a95 100644 --- a/stem/__init__.py +++ b/stem/__init__.py @@ -507,6 +507,8 @@ import stem.util import stem.util.enum +from typing import Any, Optional, Sequence + __version__ = '1.8.0-dev' __author__ = 'Damian Johnson' __contact__ = 'atagar@torproject.org' @@ -565,7 +567,7 @@ ] # Constant that we use by default for our User-Agent when downloading descriptors -stem.USER_AGENT = 'Stem/%s' % __version__ +USER_AGENT = 'Stem/%s' % __version__ # Constant to indicate an undefined argument default. Usually we'd use None for # this, but users will commonly provide None as the argument so need something @@ -584,7 +586,7 @@ class Endpoint(object): :var int port: port of the endpoint """ - def __init__(self, address, port): + def __init__(self, address: str, port: int) -> None: if not stem.util.connection.is_valid_ipv4_address(address) and not stem.util.connection.is_valid_ipv6_address(address): raise ValueError("'%s' isn't a valid IPv4 or IPv6 address" % address) elif not stem.util.connection.is_valid_port(port): @@ -593,13 +595,13 @@ def __init__(self, address, port): self.address = address self.port = int(port) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'address', 'port', cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Endpoint) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -610,11 +612,11 @@ class ORPort(Endpoint): :var list link_protocols: link protocol version we're willing to establish """ - def __init__(self, address, port, link_protocols = None): + def __init__(self, address: str, port: int, link_protocols: Optional[Sequence['stem.client.datatype.LinkProtocol']] = None) -> None: # type: ignore super(ORPort, self).__init__(address, port) self.link_protocols = link_protocols - def 
__hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'link_protocols', parent = Endpoint, cache = True) @@ -642,7 +644,9 @@ class OperationFailed(ControllerError): message """ - def __init__(self, code = None, message = None): + # TODO: should the code be an int instead? + + def __init__(self, code: Optional[str] = None, message: Optional[str] = None) -> None: super(ControllerError, self).__init__(message) self.code = code self.message = message @@ -658,10 +662,10 @@ class CircuitExtensionFailed(UnsatisfiableRequest): """ An attempt to create or extend a circuit failed. - :var stem.response.CircuitEvent circ: response notifying us of the failure + :var stem.response.events.CircuitEvent circ: response notifying us of the failure """ - def __init__(self, message, circ = None): + def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None: # type: ignore super(CircuitExtensionFailed, self).__init__(message = message) self.circ = circ @@ -674,7 +678,7 @@ class DescriptorUnavailable(UnsatisfiableRequest): Subclassed under UnsatisfiableRequest rather than OperationFailed. """ - def __init__(self, message): + def __init__(self, message: str) -> None: super(DescriptorUnavailable, self).__init__(message = message) @@ -685,7 +689,7 @@ class Timeout(UnsatisfiableRequest): .. 
versionadded:: 1.7.0 """ - def __init__(self, message): + def __init__(self, message: str) -> None: super(Timeout, self).__init__(message = message) @@ -705,7 +709,7 @@ class InvalidArguments(InvalidRequest): :var list arguments: a list of arguments which were invalid """ - def __init__(self, code = None, message = None, arguments = None): + def __init__(self, code: Optional[str] = None, message: Optional[str] = None, arguments: Optional[Sequence[str]] = None): super(InvalidArguments, self).__init__(code, message) self.arguments = arguments @@ -736,7 +740,7 @@ class DownloadFailed(IOError): :var str stacktrace_str: string representation of the stacktrace """ - def __init__(self, url, error, stacktrace, message = None): + def __init__(self, url: str, error: BaseException, stacktrace: Any, message: Optional[str] = None) -> None: if message is None: # The string representation of exceptions can reside in several places. # urllib.URLError use a 'reason' attribute that in turn may referrence @@ -773,7 +777,7 @@ class DownloadTimeout(DownloadFailed): .. 
versionadded:: 1.8.0 """ - def __init__(self, url, error, stacktrace, timeout): + def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: float): message = 'Failed to download from %s: %0.1f second timeout reached' % (url, timeout) super(DownloadTimeout, self).__init__(url, error, stacktrace, message) @@ -917,7 +921,7 @@ def __init__(self, url, error, stacktrace, timeout): ) # StreamClosureReason is a superset of RelayEndReason -StreamClosureReason = stem.util.enum.UppercaseEnum(*(RelayEndReason.keys() + [ +StreamClosureReason = stem.util.enum.UppercaseEnum(*(RelayEndReason.keys() + [ # type: ignore 'END', 'PRIVATE_ADDR', ])) diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 57cd3457e..8726bdbff 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -33,9 +33,13 @@ import stem.socket import stem.util.connection +from types import TracebackType +from typing import Dict, Iterator, List, Optional, Sequence, Type, Union + from stem.client.cell import ( CELL_TYPE_SIZE, FIXED_PAYLOAD_LEN, + PAYLOAD_LEN_SIZE, Cell, ) @@ -63,15 +67,15 @@ class Relay(object): :var int link_protocol: link protocol version we established """ - def __init__(self, orport, link_protocol): + def __init__(self, orport: stem.socket.RelaySocket, link_protocol: int) -> None: self.link_protocol = LinkProtocol(link_protocol) self._orport = orport self._orport_buffer = b'' # unread bytes self._orport_lock = threading.RLock() - self._circuits = {} + self._circuits = {} # type: Dict[int, stem.client.Circuit] @staticmethod - def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS): + def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore """ Establishes a connection with the given ORPort. 
@@ -118,7 +122,7 @@ def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS): # first VERSIONS cell, always have CIRCID_LEN == 2 for backward # compatibility. - conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) + conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore response = conn.recv() # Link negotiation ends right away if we lack a common protocol @@ -128,12 +132,12 @@ def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS): conn.close() raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port)) - versions_reply = stem.client.cell.Cell.pop(response, 2)[0] + versions_reply = stem.client.cell.Cell.pop(response, 2)[0] # type: stem.client.cell.VersionsCell # type: ignore common_protocols = set(link_protocols).intersection(versions_reply.versions) if not common_protocols: conn.close() - raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(link_protocols), address, port, ', '.join(versions_reply.versions))) + raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions)))) # Establishing connections requires sending a NETINFO, but including our # address is optional. We can revisit including it when we have a usecase @@ -144,7 +148,10 @@ def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS): return Relay(conn, link_protocol) - def _recv(self, raw = False): + def _recv_bytes(self) -> bytes: + return self._recv(True) # type: ignore + + def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell': """ Reads the next cell from our ORPort. If none is present this blocks until one is available. 
@@ -169,23 +176,23 @@ def _recv(self, raw = False): else: # variable length, our next field is the payload size - while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN.size): + while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN): self._orport_buffer += self._orport.recv() # read until we know the cell size - payload_len = FIXED_PAYLOAD_LEN.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0] - cell_size = circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN.size + payload_len + payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0] + cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len while len(self._orport_buffer) < cell_size: self._orport_buffer += self._orport.recv() # read until we have the full cell if raw: content, self._orport_buffer = split(self._orport_buffer, cell_size) - return content + return content # type: ignore else: cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol) return cell - def _msg(self, cell): + def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']: """ Sends a cell on the ORPort and provides the response we receive in reply. @@ -210,14 +217,14 @@ def _msg(self, cell): :returns: **generator** with the cells received in reply """ + # TODO: why is this an iterator? + self._orport.recv(timeout = 0) # discard unread data self._orport.send(cell.pack(self.link_protocol)) response = self._orport.recv(timeout = 1) + yield stem.client.cell.Cell.pop(response, self.link_protocol)[0] - for received_cell in stem.client.cell.Cell.pop(response, self.link_protocol): - yield received_cell - - def is_alive(self): + def is_alive(self) -> bool: """ Checks if our socket is currently connected. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.is_alive` method. 
@@ -227,7 +234,7 @@ def is_alive(self): return self._orport.is_alive() - def connection_time(self): + def connection_time(self) -> float: """ Provides the unix timestamp for when our socket was either connected or disconnected. That is to say, the time we connected if we're currently @@ -239,7 +246,7 @@ def connection_time(self): return self._orport.connection_time() - def close(self): + def close(self) -> None: """ Closes our socket connection. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.close` method. @@ -248,7 +255,7 @@ def close(self): with self._orport_lock: return self._orport.close() - def create_circuit(self): + def create_circuit(self) -> 'stem.client.Circuit': """ Establishes a new circuit. """ @@ -277,15 +284,15 @@ def create_circuit(self): return circ - def __iter__(self): + def __iter__(self) -> Iterator['stem.client.Circuit']: with self._orport_lock: for circ in self._circuits.values(): yield circ - def __enter__(self): + def __enter__(self) -> 'stem.client.Relay': return self - def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: self.close() @@ -304,14 +311,14 @@ class Circuit(object): :raises: **ImportError** if the cryptography module is unavailable """ - def __init__(self, relay, circ_id, kdf): + def __init__(self, relay: 'stem.client.Relay', circ_id: int, kdf: 'stem.client.datatype.KDF') -> None: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend except ImportError: raise ImportError('Circuit construction requires the cryptography module') - ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8)) + ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8)) # type: ignore self.relay = relay self.id = circ_id @@ -320,7 +327,7 @@ def __init__(self, relay, circ_id, kdf): self.forward_key = 
Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor() self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor() - def directory(self, request, stream_id = 0): + def directory(self, request: str, stream_id: int = 0) -> bytes: """ Request descriptors from the relay. @@ -334,13 +341,13 @@ def directory(self, request, stream_id = 0): self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id) self._send(RelayCommand.DATA, request, stream_id = stream_id) - response = [] + response = [] # type: List[stem.client.cell.RelayCell] while True: # Decrypt relay cells received in response. Our digest/key only # updates when handled successfully. - encrypted_cell = self.relay._recv(raw = True) + encrypted_cell = self.relay._recv_bytes() decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest) @@ -355,7 +362,7 @@ def directory(self, request, stream_id = 0): else: response.append(decrypted_cell) - def _send(self, command, data = '', stream_id = 0): + def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None: """ Sends a message over the circuit. 
@@ -375,13 +382,13 @@ def _send(self, command, data = '', stream_id = 0): self.forward_digest = forward_digest self.forward_key = forward_key - def close(self): + def close(self) -> None: with self.relay._orport_lock: self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) del self.relay._circuits[self.id] - def __enter__(self): + def __enter__(self) -> 'stem.client.Circuit': return self - def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: self.close() diff --git a/stem/client/cell.py b/stem/client/cell.py index 838885566..c88ba7166 100644 --- a/stem/client/cell.py +++ b/stem/client/cell.py @@ -49,10 +49,12 @@ from stem.client.datatype import HASH_LEN, ZERO, LinkProtocol, Address, Certificate, CloseReason, RelayCommand, Size, split from stem.util import datetime_to_unix, str_tools +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type, Union + FIXED_PAYLOAD_LEN = 509 # PAYLOAD_LEN, per tor-spec section 0.2 AUTH_CHALLENGE_SIZE = 32 -CELL_TYPE_SIZE = Size.CHAR +CELL_TYPE_SIZE = Size.CHAR # type: stem.client.datatype.Size PAYLOAD_LEN_SIZE = Size.SHORT RELAY_DIGEST_SIZE = Size.LONG @@ -96,17 +98,19 @@ class Cell(object): VALUE = -1 IS_FIXED_SIZE = False - def __init__(self, unused = b''): + def __init__(self, unused: bytes = b'') -> None: super(Cell, self).__init__() self.unused = unused @staticmethod - def by_name(name): + def by_name(name: str) -> Type['stem.client.cell.Cell']: """ Provides cell attributes by its name. 
:param str name: cell command to fetch + :returns: cell class with this name + :raises: **ValueError** if cell type is invalid """ @@ -117,12 +121,14 @@ def by_name(name): raise ValueError("'%s' isn't a valid cell type" % name) @staticmethod - def by_value(value): + def by_value(value: int) -> Type['stem.client.cell.Cell']: """ Provides cell attributes by its value. :param int value: cell value to fetch + :returns: cell class with this numeric value + :raises: **ValueError** if cell type is invalid """ @@ -132,11 +138,11 @@ def by_value(value): raise ValueError("'%s' isn't a valid cell value" % value) - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: raise NotImplementedError('Packing not yet implemented for %s cells' % type(self).NAME) @staticmethod - def unpack(content, link_protocol): + def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Iterator['stem.client.cell.Cell']: """ Unpacks all cells from a response. @@ -155,7 +161,7 @@ def unpack(content, link_protocol): yield cell @staticmethod - def pop(content, link_protocol): + def pop(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Tuple['stem.client.cell.Cell', bytes]: """ Unpacks the first cell. @@ -187,7 +193,7 @@ def pop(content, link_protocol): return cls._unpack(payload, circ_id, link_protocol), content @classmethod - def _pack(cls, link_protocol, payload, unused = b'', circ_id = None): + def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: Optional[int] = None) -> bytes: """ Provides bytes that can be used on the wire for these cell attributes. 
Format of a properly packed cell depends on if it's fixed or variable @@ -241,13 +247,13 @@ def _pack(cls, link_protocol, payload, unused = b'', circ_id = None): return bytes(cell) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.Cell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.Cell': """ Subclass implementation for unpacking cell content. :param bytes content: payload to decode - :param stem.client.datatype.LinkProtocol link_protocol: link protocol version :param int circ_id: circuit id cell is for + :param stem.client.datatype.LinkProtocol link_protocol: link protocol version :returns: instance of this cell type @@ -256,10 +262,10 @@ def _unpack(cls, content, circ_id, link_protocol): raise NotImplementedError('Unpacking not yet implemented for %s cells' % cls.NAME) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Cell) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -270,7 +276,7 @@ class CircuitCell(Cell): :var int circ_id: circuit id """ - def __init__(self, circ_id, unused = b''): + def __init__(self, circ_id: int, unused: bytes = b'') -> None: super(CircuitCell, self).__init__(unused) self.circ_id = circ_id @@ -286,7 +292,7 @@ class PaddingCell(Cell): VALUE = 0 IS_FIXED_SIZE = True - def __init__(self, payload = None): + def __init__(self, payload: Optional[bytes] = None) -> None: if not payload: payload = os.urandom(FIXED_PAYLOAD_LEN) elif len(payload) != FIXED_PAYLOAD_LEN: @@ -295,14 +301,14 @@ def __init__(self, payload = None): super(PaddingCell, self).__init__() self.payload = payload - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return PaddingCell._pack(link_protocol, self.payload) @classmethod - def _unpack(cls, content, circ_id, 
link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.PaddingCell': return PaddingCell(content) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'payload', cache = True) @@ -311,8 +317,8 @@ class CreateCell(CircuitCell): VALUE = 1 IS_FIXED_SIZE = True - def __init__(self): - super(CreateCell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(CreateCell, self).__init__(circ_id, unused) # TODO: implement class CreatedCell(CircuitCell): @@ -320,8 +326,8 @@ class CreatedCell(CircuitCell): VALUE = 2 IS_FIXED_SIZE = True - def __init__(self): - super(CreatedCell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(CreatedCell, self).__init__(circ_id, unused) # TODO: implement class RelayCell(CircuitCell): @@ -346,13 +352,13 @@ class RelayCell(CircuitCell): VALUE = 3 IS_FIXED_SIZE = True - def __init__(self, circ_id, command, data, digest = 0, stream_id = 0, recognized = 0, unused = b''): + def __init__(self, circ_id: int, command, data: Union[bytes, str], digest: Union[int, bytes, str, 'hashlib._HASH'] = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None: # type: ignore if 'hash' in str(type(digest)).lower(): # Unfortunately hashlib generates from a dynamic private class so # isinstance() isn't such a great option. With python2/python3 the # name is 'hashlib.HASH' whereas PyPy calls it just 'HASH' or 'Hash'. 
- digest_packed = digest.digest()[:RELAY_DIGEST_SIZE.size] + digest_packed = digest.digest()[:RELAY_DIGEST_SIZE.size] # type: ignore digest = RELAY_DIGEST_SIZE.unpack(digest_packed) elif isinstance(digest, (bytes, str)): digest_packed = digest[:RELAY_DIGEST_SIZE.size] @@ -375,7 +381,7 @@ def __init__(self, circ_id, command, data, digest = 0, stream_id = 0, recognized elif stream_id and self.command in STREAM_ID_DISALLOWED: raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command) - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = bytearray() payload += Size.CHAR.pack(self.command_int) payload += Size.SHORT.pack(self.recognized) @@ -387,7 +393,7 @@ def pack(self, link_protocol): return RelayCell._pack(link_protocol, bytes(payload), self.unused, self.circ_id) @staticmethod - def decrypt(link_protocol, content, key, digest): + def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib._HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib._HASH']: # type: ignore """ Decrypts content as a relay cell addressed to us. This provides back a tuple of the form... @@ -441,7 +447,7 @@ def decrypt(link_protocol, content, key, digest): return cell, new_key, new_digest - def encrypt(self, link_protocol, key, digest): + def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib._HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib._HASH']: # type: ignore """ Encrypts our cell content to be sent with the given key. This provides back a tuple of the form... 
@@ -477,7 +483,7 @@ def encrypt(self, link_protocol, key, digest): return header + new_key.update(payload), new_key, new_digest @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.RelayCell': command, content = Size.CHAR.pop(content) recognized, content = Size.SHORT.pop(content) # 'recognized' field stream_id, content = Size.SHORT.pop(content) @@ -490,7 +496,7 @@ def _unpack(cls, content, circ_id, link_protocol): return RelayCell(circ_id, command, data, digest, stream_id, recognized, unused) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'command_int', 'stream_id', 'digest', 'data', cache = True) @@ -506,19 +512,19 @@ class DestroyCell(CircuitCell): VALUE = 4 IS_FIXED_SIZE = True - def __init__(self, circ_id, reason = CloseReason.NONE, unused = b''): + def __init__(self, circ_id: int, reason: 'stem.client.datatype.CloseReason' = CloseReason.NONE, unused: bytes = b'') -> None: super(DestroyCell, self).__init__(circ_id, unused) self.reason, self.reason_int = CloseReason.get(reason) - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return DestroyCell._pack(link_protocol, Size.CHAR.pack(self.reason_int), self.unused, self.circ_id) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.DestroyCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.DestroyCell': reason, unused = Size.CHAR.pop(content) return DestroyCell(circ_id, reason, unused) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'circ_id', 'reason_int', cache = True) @@ -534,7 +540,7 @@ class CreateFastCell(CircuitCell): VALUE = 5 IS_FIXED_SIZE = True - def __init__(self, circ_id, key_material = None, unused = b''): + def 
__init__(self, circ_id: int, key_material: Optional[bytes] = None, unused: bytes = b'') -> None: if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -543,11 +549,11 @@ def __init__(self, circ_id, key_material = None, unused = b''): super(CreateFastCell, self).__init__(circ_id, unused) self.key_material = key_material - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return CreateFastCell._pack(link_protocol, self.key_material, self.unused, self.circ_id) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreateFastCell': key_material, unused = split(content, HASH_LEN) if len(key_material) != HASH_LEN: @@ -555,7 +561,7 @@ def _unpack(cls, content, circ_id, link_protocol): return CreateFastCell(circ_id, key_material, unused) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'circ_id', 'key_material', cache = True) @@ -571,7 +577,7 @@ class CreatedFastCell(CircuitCell): VALUE = 6 IS_FIXED_SIZE = True - def __init__(self, circ_id, derivative_key, key_material = None, unused = b''): + def __init__(self, circ_id: int, derivative_key: bytes, key_material: Optional[bytes] = None, unused: bytes = b'') -> None: if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -584,11 +590,11 @@ def __init__(self, circ_id, derivative_key, key_material = None, unused = b''): self.key_material = key_material self.derivative_key = derivative_key - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return CreatedFastCell._pack(link_protocol, self.key_material + self.derivative_key, self.unused, self.circ_id) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, 
link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreatedFastCell': if len(content) < HASH_LEN * 2: raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content))) @@ -597,7 +603,7 @@ def _unpack(cls, content, circ_id, link_protocol): return CreatedFastCell(circ_id, derivative_key, key_material, content) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'circ_id', 'derivative_key', 'key_material', cache = True) @@ -612,16 +618,16 @@ class VersionsCell(Cell): VALUE = 7 IS_FIXED_SIZE = False - def __init__(self, versions): + def __init__(self, versions: Sequence[int]) -> None: super(VersionsCell, self).__init__() self.versions = versions - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = b''.join([Size.SHORT.pack(v) for v in self.versions]) return VersionsCell._pack(link_protocol, payload) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.VersionsCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.VersionsCell': link_protocols = [] while content: @@ -630,7 +636,7 @@ def _unpack(cls, content, circ_id, link_protocol): return VersionsCell(link_protocols) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'versions', cache = True) @@ -647,13 +653,13 @@ class NetinfoCell(Cell): VALUE = 8 IS_FIXED_SIZE = True - def __init__(self, receiver_address, sender_addresses, timestamp = None, unused = b''): + def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: Optional[datetime.datetime] = None, unused: bytes = b'') -> None: super(NetinfoCell, self).__init__(unused) self.timestamp = timestamp if timestamp else datetime.datetime.now() self.receiver_address = 
receiver_address self.sender_addresses = sender_addresses - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = bytearray() payload += Size.LONG.pack(int(datetime_to_unix(self.timestamp))) payload += self.receiver_address.pack() @@ -665,7 +671,7 @@ def pack(self, link_protocol): return NetinfoCell._pack(link_protocol, bytes(payload), self.unused) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.NetinfoCell': timestamp, content = Size.LONG.pop(content) receiver_address, content = Address.pop(content) @@ -678,7 +684,7 @@ def _unpack(cls, content, circ_id, link_protocol): return NetinfoCell(receiver_address, sender_addresses, datetime.datetime.utcfromtimestamp(timestamp), unused = content) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'timestamp', 'receiver_address', 'sender_addresses', cache = True) @@ -687,8 +693,8 @@ class RelayEarlyCell(CircuitCell): VALUE = 9 IS_FIXED_SIZE = True - def __init__(self): - super(RelayEarlyCell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(RelayEarlyCell, self).__init__(circ_id, unused) # TODO: implement class Create2Cell(CircuitCell): @@ -696,8 +702,8 @@ class Create2Cell(CircuitCell): VALUE = 10 IS_FIXED_SIZE = True - def __init__(self): - super(Create2Cell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(Create2Cell, self).__init__(circ_id, unused) # TODO: implement class Created2Cell(Cell): @@ -705,7 +711,7 @@ class Created2Cell(Cell): VALUE = 11 IS_FIXED_SIZE = True - def __init__(self): + def __init__(self) -> None: super(Created2Cell, self).__init__() # TODO: implement @@ -714,7 +720,7 @@ class PaddingNegotiateCell(Cell): VALUE = 12 IS_FIXED_SIZE = True - def 
__init__(self): + def __init__(self) -> None: super(PaddingNegotiateCell, self).__init__() # TODO: implement @@ -729,7 +735,7 @@ class VPaddingCell(Cell): VALUE = 128 IS_FIXED_SIZE = False - def __init__(self, size = None, payload = None): + def __init__(self, size: Optional[int] = None, payload: Optional[bytes] = None) -> None: if size is None and payload is None: raise ValueError('VPaddingCell constructor must specify payload or size') elif size is not None and size < 0: @@ -740,14 +746,14 @@ def __init__(self, size = None, payload = None): super(VPaddingCell, self).__init__() self.payload = payload if payload is not None else os.urandom(size) - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return VPaddingCell._pack(link_protocol, self.payload) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.VPaddingCell': return VPaddingCell(payload = content) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'payload', cache = True) @@ -762,17 +768,17 @@ class CertsCell(Cell): VALUE = 129 IS_FIXED_SIZE = False - def __init__(self, certs, unused = b''): + def __init__(self, certs: Sequence['stem.client.datatype.Certificate'], unused: bytes = b'') -> None: super(CertsCell, self).__init__(unused) self.certificates = certs - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return CertsCell._pack(link_protocol, Size.CHAR.pack(len(self.certificates)) + b''.join([cert.pack() for cert in self.certificates]), self.unused) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CertsCell': cert_count, content = Size.CHAR.pop(content) - certs = [] + certs = [] # 
type: List[stem.client.datatype.Certificate] for i in range(cert_count): if not content: @@ -783,7 +789,7 @@ def _unpack(cls, content, circ_id, link_protocol): return CertsCell(certs, unused = content) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'certificates', cache = True) @@ -800,7 +806,7 @@ class AuthChallengeCell(Cell): VALUE = 130 IS_FIXED_SIZE = False - def __init__(self, methods, challenge = None, unused = b''): + def __init__(self, methods: Sequence[int], challenge: Optional[bytes] = None, unused: bytes = b'') -> None: if not challenge: challenge = os.urandom(AUTH_CHALLENGE_SIZE) elif len(challenge) != AUTH_CHALLENGE_SIZE: @@ -810,7 +816,7 @@ def __init__(self, methods, challenge = None, unused = b''): self.challenge = challenge self.methods = methods - def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = bytearray() payload += self.challenge payload += Size.SHORT.pack(len(self.methods)) @@ -821,7 +827,7 @@ def pack(self, link_protocol): return AuthChallengeCell._pack(link_protocol, bytes(payload), self.unused) @classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.AuthChallengeCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.AuthChallengeCell': min_size = AUTH_CHALLENGE_SIZE + Size.SHORT.size if len(content) < min_size: raise ValueError('AUTH_CHALLENGE payload should be at least %i bytes, but was %i' % (min_size, len(content))) @@ -840,7 +846,7 @@ def _unpack(cls, content, circ_id, link_protocol): return AuthChallengeCell(methods, challenge, unused = content) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'challenge', 'methods', cache = True) @@ -849,7 +855,7 @@ class AuthenticateCell(Cell): VALUE = 131 IS_FIXED_SIZE = False - def __init__(self): + def __init__(self) -> None: super(AuthenticateCell, 
self).__init__() # TODO: implement @@ -858,5 +864,5 @@ class AuthorizeCell(Cell): VALUE = 132 IS_FIXED_SIZE = False - def __init__(self): + def __init__(self) -> None: super(AuthorizeCell, self).__init__() # TODO: implement diff --git a/stem/client/datatype.py b/stem/client/datatype.py index 4f7110e92..acc9ec34e 100644 --- a/stem/client/datatype.py +++ b/stem/client/datatype.py @@ -144,6 +144,8 @@ import stem.util.connection import stem.util.enum +from typing import Any, Optional, Tuple, Union + ZERO = b'\x00' HASH_LEN = 20 KEY_LEN = 16 @@ -155,17 +157,17 @@ class _IntegerEnum(stem.util.enum.Enum): **UNKNOWN** value for integer values that lack a mapping. """ - def __init__(self, *args): + def __init__(self, *args: Union[Tuple[str, int], Tuple[str, str, int]]) -> None: self._enum_to_int = {} self._int_to_enum = {} parent_args = [] for entry in args: if len(entry) == 2: - enum, int_val = entry + enum, int_val = entry # type: ignore str_val = enum elif len(entry) == 3: - enum, str_val, int_val = entry + enum, str_val, int_val = entry # type: ignore else: raise ValueError('IntegerEnums can only be constructed with two or three value tuples: %s' % repr(entry)) @@ -176,7 +178,7 @@ def __init__(self, *args): parent_args.append(('UNKNOWN', 'UNKNOWN')) super(_IntegerEnum, self).__init__(*parent_args) - def get(self, val): + def get(self, val: Union[int, str]) -> Tuple[str, int]: """ Provides the (enum, int_value) tuple for a given value. """ @@ -246,7 +248,7 @@ def get(self, val): ) -def split(content, size): +def split(content: bytes, size: int) -> Tuple[bytes, bytes]: """ Simple split of bytes into two substrings. @@ -270,28 +272,25 @@ class LinkProtocol(int): from a range that's determined by our link protocol. 
""" - def __new__(cls, version): - if isinstance(version, LinkProtocol): - return version # already a LinkProtocol - - protocol = int.__new__(cls, version) - protocol.version = version - protocol.circ_id_size = Size.LONG if version > 3 else Size.SHORT - protocol.first_circ_id = 0x80000000 if version > 3 else 0x01 + def __new__(self, version: int) -> 'stem.client.datatype.LinkProtocol': + return int.__new__(self, version) # type: ignore - cell_header_size = protocol.circ_id_size.size + 1 # circuit id (2 or 4 bytes) + command (1 byte) - protocol.fixed_cell_length = cell_header_size + stem.client.cell.FIXED_PAYLOAD_LEN + def __init__(self, version: int) -> None: + self.version = version + self.circ_id_size = Size.LONG if version > 3 else Size.SHORT + self.first_circ_id = 0x80000000 if version > 3 else 0x01 - return protocol + cell_header_size = self.circ_id_size.size + 1 # circuit id (2 or 4 bytes) + command (1 byte) + self.fixed_cell_length = cell_header_size + stem.client.cell.FIXED_PAYLOAD_LEN - def __hash__(self): + def __hash__(self) -> int: # All LinkProtocol attributes can be derived from our version, so that's # all we need in our hash. Offsetting by our type so we don't hash conflict # with ints. return self.version * hash(str(type(self))) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, int): return self.version == other elif isinstance(other, LinkProtocol): @@ -299,10 +298,10 @@ def __eq__(self, other): else: return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __int__(self): + def __int__(self) -> int: return self.version @@ -311,7 +310,7 @@ class Field(object): Packable and unpackable datatype. """ - def pack(self): + def pack(self) -> bytes: """ Encodes field into bytes. 
@@ -323,7 +322,7 @@ def pack(self): raise NotImplementedError('Not yet available') @classmethod - def unpack(cls, packed): + def unpack(cls, packed: bytes) -> 'stem.client.datatype.Field': """ Decodes bytes into a field of this type. @@ -342,7 +341,7 @@ def unpack(cls, packed): return unpacked @staticmethod - def pop(packed): + def pop(packed: bytes) -> Tuple[Any, bytes]: """ Decodes bytes as this field type, providing it and the remainder. @@ -355,10 +354,10 @@ def pop(packed): raise NotImplementedError('Not yet available') - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Field) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -378,15 +377,20 @@ class Size(Field): ==================== =========== """ - def __init__(self, name, size): + CHAR = None # type: Optional[stem.client.datatype.Size] + SHORT = None # type: Optional[stem.client.datatype.Size] + LONG = None # type: Optional[stem.client.datatype.Size] + LONG_LONG = None # type: Optional[stem.client.datatype.Size] + + def __init__(self, name: str, size: int) -> None: self.name = name self.size = size @staticmethod - def pop(packed): + def pop(packed: bytes) -> Tuple[int, bytes]: raise NotImplementedError("Use our constant's unpack() and pop() instead") - def pack(self, content): + def pack(self, content: int) -> bytes: # type: ignore try: return content.to_bytes(self.size, 'big') except: @@ -397,18 +401,18 @@ def pack(self, content): else: raise - def unpack(self, packed): + def unpack(self, packed: bytes) -> int: # type: ignore if self.size != len(packed): raise ValueError('%s is the wrong size for a %s field' % (repr(packed), self.name)) return int.from_bytes(packed, 'big') - def pop(self, packed): + def pop(self, packed: bytes) -> Tuple[int, bytes]: # type: ignore to_unpack, remainder = split(packed, self.size) return self.unpack(to_unpack), remainder - def __hash__(self): + def 
__hash__(self) -> int: return stem.util._hash_attr(self, 'name', 'size', cache = True) @@ -418,50 +422,55 @@ class Address(Field): :var stem.client.AddrType type: address type :var int type_int: integer value of the address type - :var unicode value: address value + :var str value: address value :var bytes value_bin: encoded address value """ - def __init__(self, value, addr_type = None): + def __init__(self, value: Union[bytes, str], addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None: if addr_type is None: - if stem.util.connection.is_valid_ipv4_address(value): + if stem.util.connection.is_valid_ipv4_address(value): # type: ignore addr_type = AddrType.IPv4 - elif stem.util.connection.is_valid_ipv6_address(value): + elif stem.util.connection.is_valid_ipv6_address(value): # type: ignore addr_type = AddrType.IPv6 else: - raise ValueError("'%s' isn't an IPv4 or IPv6 address" % value) + raise ValueError("'%s' isn't an IPv4 or IPv6 address" % stem.util.str_tools._to_unicode(value)) + + value_bytes = stem.util.str_tools._to_bytes(value) + + self.value = None # type: Optional[str] + self.value_bin = None # type: Optional[bytes] self.type, self.type_int = AddrType.get(addr_type) if self.type == AddrType.IPv4: - if stem.util.connection.is_valid_ipv4_address(value): - self.value = value - self.value_bin = b''.join([Size.CHAR.pack(int(v)) for v in value.split('.')]) + if stem.util.connection.is_valid_ipv4_address(value_bytes): # type: ignore + self.value = stem.util.str_tools._to_unicode(value_bytes) + self.value_bin = b''.join([Size.CHAR.pack(int(v)) for v in value_bytes.split(b'.')]) else: - if len(value) != 4: + if len(value_bytes) != 4: raise ValueError('Packed IPv4 addresses should be four bytes, but was: %s' % repr(value)) - self.value = _unpack_ipv4_address(value) - self.value_bin = value + self.value = _unpack_ipv4_address(value_bytes) + self.value_bin = value_bytes elif self.type == AddrType.IPv6: - if 
stem.util.connection.is_valid_ipv6_address(value): - self.value = stem.util.connection.expand_ipv6_address(value).lower() + if stem.util.connection.is_valid_ipv6_address(value_bytes): # type: ignore + self.value = stem.util.connection.expand_ipv6_address(value_bytes).lower() # type: ignore self.value_bin = b''.join([Size.SHORT.pack(int(v, 16)) for v in self.value.split(':')]) else: - if len(value) != 16: + if len(value_bytes) != 16: raise ValueError('Packed IPv6 addresses should be sixteen bytes, but was: %s' % repr(value)) - self.value = _unpack_ipv6_address(value) - self.value_bin = value + self.value = _unpack_ipv6_address(value_bytes) + self.value_bin = value_bytes else: # The spec doesn't really tell us what form to expect errors to be. For # now just leaving the value unset so we can fill it in later when we # know what would be most useful. self.value = None - self.value_bin = value + self.value_bin = value_bytes - def pack(self): + def pack(self) -> bytes: cell = bytearray() cell += Size.CHAR.pack(self.type_int) cell += Size.CHAR.pack(len(self.value_bin)) @@ -469,7 +478,7 @@ def pack(self): return bytes(cell) @staticmethod - def pop(content): + def pop(content: bytes) -> Tuple['stem.client.datatype.Address', bytes]: addr_type, content = Size.CHAR.pop(content) addr_length, content = Size.CHAR.pop(content) @@ -480,7 +489,7 @@ def pop(content): return Address(addr_value, addr_type), content - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'type_int', 'value_bin', cache = True) @@ -493,11 +502,11 @@ class Certificate(Field): :var bytes value: certificate value """ - def __init__(self, cert_type, value): + def __init__(self, cert_type: Union[int, 'stem.client.datatype.CertType'], value: bytes) -> None: self.type, self.type_int = CertType.get(cert_type) self.value = value - def pack(self): + def pack(self) -> bytes: cell = bytearray() cell += Size.CHAR.pack(self.type_int) cell += Size.SHORT.pack(len(self.value)) @@ -505,7 
+514,7 @@ def pack(self): return bytes(cell) @staticmethod - def pop(content): + def pop(content: bytes) -> Tuple['stem.client.datatype.Certificate', bytes]: cert_type, content = Size.CHAR.pop(content) cert_size, content = Size.SHORT.pop(content) @@ -515,7 +524,7 @@ def pop(content): cert_bytes, content = split(content, cert_size) return Certificate(cert_type, cert_bytes), content - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'type_int', 'value') @@ -532,12 +541,12 @@ class LinkSpecifier(Field): :var bytes value: encoded link specification destination """ - def __init__(self, link_type, value): + def __init__(self, link_type: int, value: bytes) -> None: self.type = link_type self.value = value @staticmethod - def pop(packed): + def pop(packed: bytes) -> Tuple['stem.client.datatype.LinkSpecifier', bytes]: # LSTYPE (Link specifier type) [1 byte] # LSLEN (Link specifier length) [1 byte] # LSPEC (Link specifier) [LSLEN bytes] @@ -561,7 +570,7 @@ def pop(packed): else: return LinkSpecifier(link_type, value), packed # unrecognized type - def pack(self): + def pack(self) -> bytes: cell = bytearray() cell += Size.CHAR.pack(self.type) cell += Size.CHAR.pack(len(self.value)) @@ -579,16 +588,16 @@ class LinkByIPv4(LinkSpecifier): :var int port: relay ORPort """ - def __init__(self, address, port): + def __init__(self, address: str, port: int) -> None: super(LinkByIPv4, self).__init__(0, _pack_ipv4_address(address) + Size.SHORT.pack(port)) self.address = address self.port = port @staticmethod - def unpack(value): + def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv4': if len(value) != 6: - raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value)))) addr, port = split(value, 4) return 
LinkByIPv4(_unpack_ipv4_address(addr), Size.SHORT.unpack(port)) @@ -604,16 +613,16 @@ class LinkByIPv6(LinkSpecifier): :var int port: relay ORPort """ - def __init__(self, address, port): + def __init__(self, address: str, port: int) -> None: super(LinkByIPv6, self).__init__(1, _pack_ipv6_address(address) + Size.SHORT.pack(port)) self.address = address self.port = port @staticmethod - def unpack(value): + def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv6': if len(value) != 18: - raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value)))) addr, port = split(value, 16) return LinkByIPv6(_unpack_ipv6_address(addr), Size.SHORT.unpack(port)) @@ -628,11 +637,11 @@ class LinkByFingerprint(LinkSpecifier): :var str fingerprint: relay sha1 fingerprint """ - def __init__(self, value): + def __init__(self, value: bytes) -> None: super(LinkByFingerprint, self).__init__(2, value) if len(value) != 20: - raise ValueError('Fingerprint link specifiers should be twenty bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('Fingerprint link specifiers should be twenty bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value)))) self.fingerprint = stem.util.str_tools._to_unicode(value) @@ -646,11 +655,11 @@ class LinkByEd25519(LinkSpecifier): :var str fingerprint: relay ed25519 fingerprint """ - def __init__(self, value): + def __init__(self, value: bytes) -> None: super(LinkByEd25519, self).__init__(3, value) if len(value) != 32: - raise ValueError('Fingerprint link specifiers should be thirty two bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('Fingerprint link specifiers should be thirty two bytes, but was %i instead: 
%s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value)))) self.fingerprint = stem.util.str_tools._to_unicode(value) @@ -668,7 +677,7 @@ class KDF(collections.namedtuple('KDF', ['key_hash', 'forward_digest', 'backward """ @staticmethod - def from_value(key_material): + def from_value(key_material: bytes) -> 'stem.client.datatype.KDF': # Derived key material, as per... # # K = H(K0 | [00]) | H(K0 | [01]) | H(K0 | [02]) | ... @@ -689,19 +698,19 @@ def from_value(key_material): return KDF(key_hash, forward_digest, backward_digest, forward_key, backward_key) -def _pack_ipv4_address(address): +def _pack_ipv4_address(address: str) -> bytes: return b''.join([Size.CHAR.pack(int(v)) for v in address.split('.')]) -def _unpack_ipv4_address(value): +def _unpack_ipv4_address(value: bytes) -> str: return '.'.join([str(Size.CHAR.unpack(value[i:i + 1])) for i in range(4)]) -def _pack_ipv6_address(address): +def _pack_ipv6_address(address: str) -> bytes: return b''.join([Size.SHORT.pack(int(v, 16)) for v in address.split(':')]) -def _unpack_ipv6_address(value): +def _unpack_ipv6_address(value: bytes) -> str: return ':'.join(['%04x' % Size.SHORT.unpack(value[i * 2:(i + 1) * 2]) for i in range(8)]) diff --git a/stem/connection.py b/stem/connection.py index e3032784f..ff950a0c3 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -135,6 +135,7 @@ import stem.control import stem.response +import stem.response.protocolinfo import stem.socket import stem.util.connection import stem.util.enum @@ -142,6 +143,7 @@ import stem.util.system import stem.version +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from stem.util import log AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN') @@ -209,7 +211,7 @@ ) -def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/tor/control', password = None, password_prompt = False, chroot_path = None, controller = stem.control.Controller): +def 
connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type = stem.control.Controller) -> Any: """ Convenience function for quickly getting a control connection. This is very handy for debugging or CLI setup, handling setup and prompting for a password @@ -234,7 +236,7 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/ Use both port 9051 and 9151 by default. :param tuple contol_port: address and port tuple, for instance **('127.0.0.1', 9051)** - :param str path: path where the control socket is located + :param str control_socket: path where the control socket is located :param str password: passphrase to authenticate to the socket :param bool password_prompt: prompt for the controller password if it wasn't supplied @@ -248,6 +250,8 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/ **control_port** and **control_socket** are **None** """ + # TODO: change this function's API so we can provide a concrete type + if control_port is None and control_socket is None: raise ValueError('Neither a control port nor control socket were provided. 
Nothing to connect to.') elif control_port: @@ -258,7 +262,8 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/ elif control_port[1] != 'default' and not stem.util.connection.is_valid_port(control_port[1]): raise ValueError("'%s' isn't a valid port" % control_port[1]) - control_connection, error_msg = None, '' + control_connection = None # type: Optional[stem.socket.ControlSocket] + error_msg = '' if control_socket: if os.path.exists(control_socket): @@ -295,7 +300,7 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/ return _connect_auth(control_connection, password, password_prompt, chroot_path, controller) -def _connect_auth(control_socket, password, password_prompt, chroot_path, controller): +def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any: """ Helper for the connect_* functions that authenticates the socket and constructs the controller. @@ -361,7 +366,7 @@ def _connect_auth(control_socket, password, password_prompt, chroot_path, contro return None -def authenticate(controller, password = None, chroot_path = None, protocolinfo_response = None): +def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None: """ Authenticates to a control socket using the information provided by a PROTOCOLINFO response. 
In practice this will often be all we need to @@ -479,7 +484,7 @@ def authenticate(controller, password = None, chroot_path = None, protocolinfo_r raise AuthenticationFailure('socket connection failed (%s)' % exc) auth_methods = list(protocolinfo_response.auth_methods) - auth_exceptions = [] + auth_exceptions = [] # type: List[stem.connection.AuthenticationFailure] if len(auth_methods) == 0: raise NoAuthMethods('our PROTOCOLINFO response did not have any methods for authenticating') @@ -575,7 +580,7 @@ def authenticate(controller, password = None, chroot_path = None, protocolinfo_r raise AssertionError('BUG: Authentication failed without providing a recognized exception: %s' % str(auth_exceptions)) -def authenticate_none(controller, suppress_ctl_errors = True): +def authenticate_none(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], suppress_ctl_errors: bool = True) -> None: """ Authenticates to an open control socket. All control connections need to authenticate before they can be used, even if tor hasn't been configured to @@ -622,7 +627,7 @@ def authenticate_none(controller, suppress_ctl_errors = True): raise OpenAuthRejected('Socket failed (%s)' % exc) -def authenticate_password(controller, password, suppress_ctl_errors = True): +def authenticate_password(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: str, suppress_ctl_errors: bool = True) -> None: """ Authenticates to a control socket that uses a password (via the HashedControlPassword torrc option). Quotes in the password are escaped. 
@@ -692,7 +697,7 @@ def authenticate_password(controller, password, suppress_ctl_errors = True): raise PasswordAuthRejected('Socket failed (%s)' % exc) -def authenticate_cookie(controller, cookie_path, suppress_ctl_errors = True): +def authenticate_cookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None: """ Authenticates to a control socket that uses the contents of an authentication cookie (generated via the CookieAuthentication torrc option). This does basic @@ -782,7 +787,7 @@ def authenticate_cookie(controller, cookie_path, suppress_ctl_errors = True): raise CookieAuthRejected('Socket failed (%s)' % exc, cookie_path, False) -def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True): +def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None: """ Authenticates to a control socket using the safe cookie method, which is enabled by setting the CookieAuthentication torrc option on Tor client's which @@ -844,10 +849,11 @@ def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True) cookie_data = _read_cookie(cookie_path, True) client_nonce = os.urandom(32) + authchallenge_response = None # type: stem.response.authchallenge.AuthChallengeResponse try: client_nonce_hex = stem.util.str_tools._to_unicode(binascii.b2a_hex(client_nonce)) - authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex) + authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex) # type: ignore if not authchallenge_response.is_ok(): try: @@ -860,13 +866,18 @@ def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True) if 'Authentication required.' 
in authchallenge_response_str: raise AuthChallengeUnsupported("SAFECOOKIE authentication isn't supported", cookie_path) elif 'AUTHCHALLENGE only supports' in authchallenge_response_str: - raise UnrecognizedAuthChallengeMethod(authchallenge_response_str, cookie_path) + # TODO: This code path has been broken for years. Do we still need it? + # If so, what should authchallenge_method be? + + authchallenge_method = None + + raise UnrecognizedAuthChallengeMethod(authchallenge_response_str, cookie_path, authchallenge_method) elif 'Invalid base16 client nonce' in authchallenge_response_str: raise InvalidClientNonce(authchallenge_response_str, cookie_path) elif 'Cookie authentication is disabled' in authchallenge_response_str: raise CookieAuthRejected(authchallenge_response_str, cookie_path, True) else: - raise AuthChallengeFailed(authchallenge_response, cookie_path) + raise AuthChallengeFailed(authchallenge_response_str, cookie_path) except stem.ControllerError as exc: try: controller.connect() @@ -876,7 +887,7 @@ def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True) if not suppress_ctl_errors: raise else: - raise AuthChallengeFailed('Socket failed (%s)' % exc, cookie_path, True) + raise AuthChallengeFailed('Socket failed (%s)' % exc, cookie_path) try: stem.response.convert('AUTHCHALLENGE', authchallenge_response) @@ -931,7 +942,7 @@ def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True) raise CookieAuthRejected(str(auth_response), cookie_path, True, auth_response) -def get_protocolinfo(controller): +def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.ControlSocket]) -> stem.response.protocolinfo.ProtocolInfoResponse: """ Issues a PROTOCOLINFO query to a control socket, getting information about the tor process running on it. 
If the socket is already closed then it is @@ -968,10 +979,10 @@ def get_protocolinfo(controller): raise stem.SocketError(exc) stem.response.convert('PROTOCOLINFO', protocolinfo_response) - return protocolinfo_response + return protocolinfo_response # type: ignore -def _msg(controller, message): +def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage: """ Sends and receives a message with either a :class:`~stem.socket.ControlSocket` or :class:`~stem.control.BaseController`. @@ -984,7 +995,7 @@ def _msg(controller, message): return controller.msg(message) -def _connection_for_default_port(address): +def _connection_for_default_port(address: str) -> stem.socket.ControlPort: """ Attempts to provide a controller connection for either port 9051 (default for relays) or 9151 (default for Tor Browser). If both fail then this raises the @@ -1006,7 +1017,7 @@ def _connection_for_default_port(address): raise exc -def _read_cookie(cookie_path, is_safecookie): +def _read_cookie(cookie_path: str, is_safecookie: bool) -> bytes: """ Provides the contents of a given cookie file. @@ -1014,6 +1025,8 @@ def _read_cookie(cookie_path, is_safecookie): :param bool is_safecookie: **True** if this was for SAFECOOKIE authentication, **False** if for COOKIE + :returns: **bytes** with the cookie file content + :raises: * :class:`stem.connection.UnreadableCookieFile` if the cookie file is unreadable @@ -1048,12 +1061,12 @@ def _read_cookie(cookie_path, is_safecookie): raise UnreadableCookieFile(exc_msg, cookie_path, is_safecookie) -def _hmac_sha256(key, msg): +def _hmac_sha256(key: bytes, msg: bytes) -> bytes: """ Generates a sha256 digest using the given key and message. 
- :param str key: starting key for the hash - :param str msg: message to be hashed + :param bytes key: starting key for the hash + :param bytes msg: message to be hashed :returns: sha256 digest of msg as bytes, hashed using the given key """ @@ -1065,11 +1078,11 @@ class AuthenticationFailure(Exception): """ Base error for authentication failures. - :var stem.socket.ControlMessage auth_response: AUTHENTICATE response from the + :var stem.response.ControlMessage auth_response: AUTHENTICATE response from the control socket, **None** if one wasn't received """ - def __init__(self, message, auth_response = None): + def __init__(self, message: str, auth_response: Optional[stem.response.ControlMessage] = None) -> None: super(AuthenticationFailure, self).__init__(message) self.auth_response = auth_response @@ -1081,7 +1094,7 @@ class UnrecognizedAuthMethods(AuthenticationFailure): :var list unknown_auth_methods: authentication methods that weren't recognized """ - def __init__(self, message, unknown_auth_methods): + def __init__(self, message: str, unknown_auth_methods: Sequence[str]) -> None: super(UnrecognizedAuthMethods, self).__init__(message) self.unknown_auth_methods = unknown_auth_methods @@ -1125,7 +1138,7 @@ class CookieAuthFailed(AuthenticationFailure): authentication attempt """ - def __init__(self, message, cookie_path, is_safecookie, auth_response = None): + def __init__(self, message: str, cookie_path: str, is_safecookie: bool, auth_response: Optional[stem.response.ControlMessage] = None) -> None: super(CookieAuthFailed, self).__init__(message, auth_response) self.is_safecookie = is_safecookie self.cookie_path = cookie_path @@ -1152,7 +1165,7 @@ class AuthChallengeFailed(CookieAuthFailed): AUTHCHALLENGE command has failed. 
""" - def __init__(self, message, cookie_path): + def __init__(self, message: str, cookie_path: str) -> None: super(AuthChallengeFailed, self).__init__(message, cookie_path, True) @@ -1169,7 +1182,7 @@ class UnrecognizedAuthChallengeMethod(AuthChallengeFailed): :var str authchallenge_method: AUTHCHALLENGE method that Tor couldn't recognize """ - def __init__(self, message, cookie_path, authchallenge_method): + def __init__(self, message: str, cookie_path: str, authchallenge_method: str) -> None: super(UnrecognizedAuthChallengeMethod, self).__init__(message, cookie_path) self.authchallenge_method = authchallenge_method @@ -1201,7 +1214,7 @@ class NoAuthCookie(MissingAuthInfo): authentication, **False** if for COOKIE """ - def __init__(self, message, is_safecookie): + def __init__(self, message: str, is_safecookie: bool) -> None: super(NoAuthCookie, self).__init__(message) self.is_safecookie = is_safecookie diff --git a/stem/control.py b/stem/control.py index 4016e7625..626b2b3e3 100644 --- a/stem/control.py +++ b/stem/control.py @@ -255,7 +255,9 @@ import stem.descriptor.server_descriptor import stem.exit_policy import stem.response +import stem.response.add_onion import stem.response.events +import stem.response.protocolinfo import stem.socket import stem.util import stem.util.conf @@ -268,6 +270,8 @@ from stem import UNDEFINED, CircStatus, Signal from stem.util import log +from types import TracebackType +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union # When closing the controller we attempt to finish processing enqueued events, # but if it takes longer than this we terminate. @@ -400,7 +404,7 @@ server descriptors. As of Tor version 0.2.3.25 it downloads microdescriptors \ instead unless you set 'UseMicrodescriptors 0' in your torrc." 
-EVENT_DESCRIPTIONS = None +EVENT_DESCRIPTIONS = None # type: Dict[str, str] class AccountingStats(collections.namedtuple('AccountingStats', ['retrieved', 'status', 'interval_end', 'time_until_reset', 'read_bytes', 'read_bytes_left', 'read_limit', 'written_bytes', 'write_bytes_left', 'write_limit'])): @@ -447,15 +451,15 @@ class CreateHiddenServiceOutput(collections.namedtuple('CreateHiddenServiceOutpu """ -def with_default(yields = False): +def with_default(yields: bool = False) -> Callable: """ Provides a decorator to support having a default value. This should be treated as private. """ - def decorator(func): - def get_default(func, args, kwargs): - arg_names = inspect.getargspec(func).args[1:] # drop 'self' + def decorator(func: Callable) -> Callable: + def get_default(func: Callable, args: Any, kwargs: Any) -> Any: + arg_names = inspect.getfullargspec(func).args[1:] # drop 'self' default_position = arg_names.index('default') if 'default' in arg_names else None if default_position is not None and default_position < len(args): @@ -465,7 +469,7 @@ def get_default(func, args, kwargs): if not yields: @functools.wraps(func) - def wrapped(self, *args, **kwargs): + def wrapped(self, *args: Any, **kwargs: Any) -> Any: try: return func(self, *args, **kwargs) except: @@ -477,7 +481,7 @@ def wrapped(self, *args, **kwargs): return default else: @functools.wraps(func) - def wrapped(self, *args, **kwargs): + def wrapped(self, *args: Any, **kwargs: Any) -> Any: try: for val in func(self, *args, **kwargs): yield val @@ -496,7 +500,7 @@ def wrapped(self, *args, **kwargs): return decorator -def event_description(event): +def event_description(event: str) -> str: """ Provides a description for Tor events. 
@@ -514,7 +518,7 @@ def event_description(event): try: config.load(config_path) - EVENT_DESCRIPTIONS = dict([(key.lower()[18:], config.get_value(key)) for key in config.keys() if key.startswith('event.description.')]) + EVENT_DESCRIPTIONS = dict([(key.lower()[18:], config.get_value(key)) for key in config.keys() if key.startswith('event.description.')]) # type: ignore except Exception as exc: log.warn("BUG: stem failed to load its internal manual information from '%s': %s" % (config_path, exc)) return None @@ -538,23 +542,23 @@ class BaseController(object): socket as though it hasn't yet been authenticated. """ - def __init__(self, control_socket, is_authenticated = False): + def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: self._socket = control_socket self._msg_lock = threading.RLock() - self._status_listeners = [] # tuples of the form (callback, spawn_thread) + self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread) self._status_listeners_lock = threading.RLock() # queues where incoming messages are directed - self._reply_queue = queue.Queue() - self._event_queue = queue.Queue() + self._reply_queue = queue.Queue() # type: queue.Queue[Union[stem.response.ControlMessage, stem.ControllerError]] + self._event_queue = queue.Queue() # type: queue.Queue[stem.response.ControlMessage] # thread to continually pull from the control socket - self._reader_thread = None + self._reader_thread = None # type: Optional[threading.Thread] # thread to pull from the _event_queue and call handle_event self._event_notice = threading.Event() - self._event_thread = None + self._event_thread = None # type: Optional[threading.Thread] # saves our socket's prior _connect() and _close() methods so they can be # called along with ours @@ -562,13 +566,13 @@ def __init__(self, control_socket, is_authenticated = False): 
self._socket_connect = self._socket._connect self._socket_close = self._socket._close - self._socket._connect = self._connect - self._socket._close = self._close + self._socket._connect = self._connect # type: ignore + self._socket._close = self._close # type: ignore self._last_heartbeat = 0.0 # timestamp for when we last heard from tor self._is_authenticated = False - self._state_change_threads = [] # threads we've spawned to notify of state changes + self._state_change_threads = [] # type: List[threading.Thread] # threads we've spawned to notify of state changes if self._socket.is_alive(): self._launch_threads() @@ -576,7 +580,7 @@ def __init__(self, control_socket, is_authenticated = False): if is_authenticated: self._post_authentication() - def msg(self, message): + def msg(self, message: str) -> stem.response.ControlMessage: """ Sends a message to our control socket and provides back its reply. @@ -659,7 +663,7 @@ def msg(self, message): self.close() raise - def is_alive(self): + def is_alive(self) -> bool: """ Checks if our socket is currently connected. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.is_alive` method. @@ -669,7 +673,7 @@ def is_alive(self): return self._socket.is_alive() - def is_localhost(self): + def is_localhost(self) -> bool: """ Returns if the connection is for the local system or not. @@ -680,7 +684,7 @@ def is_localhost(self): return self._socket.is_localhost() - def connection_time(self): + def connection_time(self) -> float: """ Provides the unix timestamp for when our socket was either connected or disconnected. That is to say, the time we connected if we're currently @@ -694,7 +698,7 @@ def connection_time(self): return self._socket.connection_time() - def is_authenticated(self): + def is_authenticated(self) -> bool: """ Checks if our socket is both connected and authenticated. 
@@ -704,7 +708,7 @@ def is_authenticated(self): return self._is_authenticated if self.is_alive() else False - def connect(self): + def connect(self) -> None: """ Reconnects our control socket. This is a pass-through for our socket's :func:`~stem.socket.ControlSocket.connect` method. @@ -714,7 +718,7 @@ def connect(self): self._socket.connect() - def close(self): + def close(self) -> None: """ Closes our socket connection. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.close` method. @@ -733,7 +737,7 @@ def close(self): if t.is_alive() and threading.current_thread() != t: t.join() - def get_socket(self): + def get_socket(self) -> stem.socket.ControlSocket: """ Provides the socket used to speak with the tor process. Communicating with the socket directly isn't advised since it may confuse this controller. @@ -743,7 +747,7 @@ def get_socket(self): return self._socket - def get_latest_heartbeat(self): + def get_latest_heartbeat(self) -> float: """ Provides the unix timestamp for when we last heard from tor. This is zero if we've never received a message. @@ -753,7 +757,7 @@ def get_latest_heartbeat(self): return self._last_heartbeat - def add_status_listener(self, callback, spawn = True): + def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None: """ Notifies a given function when the state of our socket changes. Functions are expected to be of the form... @@ -783,7 +787,7 @@ def add_status_listener(self, callback, spawn = True): with self._status_listeners_lock: self._status_listeners.append((callback, spawn)) - def remove_status_listener(self, callback): + def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool: """ Stops listener from being notified of further events. 
@@ -805,13 +809,13 @@ def remove_status_listener(self, callback): self._status_listeners = new_listeners return is_changed - def __enter__(self): + def __enter__(self) -> 'stem.control.BaseController': return self - def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: self.close() - def _handle_event(self, event_message): + def _handle_event(self, event_message: stem.response.ControlMessage) -> None: """ Callback to be overwritten by subclasses for event listening. This is notified whenever we receive an event from the control socket. @@ -822,13 +826,13 @@ def _handle_event(self, event_message): pass - def _connect(self): + def _connect(self) -> None: self._launch_threads() self._notify_status_listeners(State.INIT) self._socket_connect() self._is_authenticated = False - def _close(self): + def _close(self) -> None: # Our is_alive() state is now false. Our reader thread should already be # awake from recv() raising a closure exception. Wake up the event thread # too so it can end. @@ -846,12 +850,12 @@ def _close(self): self._socket_close() - def _post_authentication(self): + def _post_authentication(self) -> None: # actions to be taken after we have a newly authenticated connection self._is_authenticated = True - def _notify_status_listeners(self, state): + def _notify_status_listeners(self, state: 'stem.control.State') -> None: """ Informs our status listeners that a state change occurred. @@ -895,7 +899,7 @@ def _notify_status_listeners(self, state): else: listener(self, state, change_timestamp) - def _launch_threads(self): + def _launch_threads(self) -> None: """ Initializes daemon threads. Threads can't be reused so we need to recreate them if we're restarted. 
@@ -915,7 +919,7 @@ def _launch_threads(self): self._event_thread.setDaemon(True) self._event_thread.start() - def _reader_loop(self): + def _reader_loop(self) -> None: """ Continually pulls from the control socket, directing the messages into queues based on their type. Controller messages come in two varieties... @@ -944,7 +948,7 @@ def _reader_loop(self): self._reply_queue.put(exc) - def _event_loop(self): + def _event_loop(self) -> None: """ Continually pulls messages from the _event_queue and sends them to our handle_event callback. This is done via its own thread so subclasses with a @@ -982,7 +986,7 @@ class Controller(BaseController): """ @staticmethod - def from_port(address = '127.0.0.1', port = 'default'): + def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller': """ Constructs a :class:`~stem.socket.ControlPort` based Controller. @@ -1011,12 +1015,12 @@ def from_port(address = '127.0.0.1', port = 'default'): if port == 'default': control_port = stem.connection._connection_for_default_port(address) else: - control_port = stem.socket.ControlPort(address, port) + control_port = stem.socket.ControlPort(address, int(port)) return Controller(control_port) @staticmethod - def from_socket_file(path = '/var/run/tor/control'): + def from_socket_file(path: str = '/var/run/tor/control') -> 'stem.control.Controller': """ Constructs a :class:`~stem.socket.ControlSocketFile` based Controller. 
@@ -1030,35 +1034,35 @@ def from_socket_file(path = '/var/run/tor/control'): control_socket = stem.socket.ControlSocketFile(path) return Controller(control_socket) - def __init__(self, control_socket, is_authenticated = False): + def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: self._is_caching_enabled = True - self._request_cache = {} + self._request_cache = {} # type: Dict[str, Any] self._last_newnym = 0.0 self._cache_lock = threading.RLock() # mapping of event types to their listeners - self._event_listeners = {} + self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], None]]] self._event_listeners_lock = threading.RLock() - self._enabled_features = [] + self._enabled_features = [] # type: List[str] - self._last_address_exc = None - self._last_fingerprint_exc = None + self._last_address_exc = None # type: Optional[BaseException] + self._last_fingerprint_exc = None # type: Optional[BaseException] super(Controller, self).__init__(control_socket, is_authenticated) - def _sighup_listener(event): + def _sighup_listener(event: stem.response.events.SignalEvent) -> None: if event.signal == Signal.RELOAD: self.clear_cache() self._notify_status_listeners(State.RESET) - self.add_event_listener(_sighup_listener, EventType.SIGNAL) + self.add_event_listener(_sighup_listener, EventType.SIGNAL) # type: ignore - def _confchanged_listener(event): + def _confchanged_listener(event: stem.response.events.ConfChangedEvent) -> None: if self.is_caching_enabled(): to_cache_changed = dict((k.lower(), v) for k, v in event.changed.items()) - to_cache_unset = dict((k.lower(), []) for k in event.unset) # [] represents None value in cache + to_cache_unset = dict((k.lower(), []) for k in event.unset) # type: Dict[str, List[str]] # [] represents None value in cache to_cache = {} to_cache.update(to_cache_changed) @@ -1068,21 +1072,21 @@ def _confchanged_listener(event): 
self._confchanged_cache_invalidation(to_cache) - self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED) + self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED) # type: ignore - def _address_changed_listener(event): + def _address_changed_listener(event: stem.response.events.StatusEvent) -> None: if event.action in ('EXTERNAL_ADDRESS', 'DNS_USELESS'): self._set_cache({'exit_policy': None}) self._set_cache({'address': None}, 'getinfo') self._last_address_exc = None - self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER) + self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER) # type: ignore - def close(self): + def close(self) -> None: self.clear_cache() super(Controller, self).close() - def authenticate(self, *args, **kwargs): + def authenticate(self, *args: Any, **kwargs: Any) -> None: """ A convenience method to authenticate the controller. This is just a pass-through to :func:`stem.connection.authenticate`. @@ -1091,7 +1095,7 @@ def authenticate(self, *args, **kwargs): import stem.connection stem.connection.authenticate(self, *args, **kwargs) - def reconnect(self, *args, **kwargs): + def reconnect(self, *args: Any, **kwargs: Any) -> None: """ Reconnects and authenticates to our control socket. 
@@ -1108,7 +1112,7 @@ def reconnect(self, *args, **kwargs): self.authenticate(*args, **kwargs) @with_default() - def get_info(self, params, default = UNDEFINED, get_bytes = False): + def get_info(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, get_bytes: bool = False) -> Union[str, Dict[str, str]]: """ get_info(params, default = UNDEFINED, get_bytes = False) @@ -1148,15 +1152,15 @@ def get_info(self, params, default = UNDEFINED, get_bytes = False): if isinstance(params, (bytes, str)): is_multiple = False - params = set([params]) + param_set = set([params]) else: if not params: return {} is_multiple = True - params = set(params) + param_set = set(params) - for param in params: + for param in param_set: if param.startswith('ip-to-country/') and param != 'ip-to-country/0.0.0.0' and self.get_info('ip-to-country/ipv4-available', '0') != '1': raise stem.ProtocolError('Tor geoip database is unavailable') elif param == 'address' and self._last_address_exc: @@ -1166,16 +1170,16 @@ def get_info(self, params, default = UNDEFINED, get_bytes = False): # check for cached results - from_cache = [param.lower() for param in params] + from_cache = [param.lower() for param in param_set] cached_results = self._get_cache_map(from_cache, 'getinfo') for key in cached_results: - user_expected_key = _case_insensitive_lookup(params, key) + user_expected_key = _case_insensitive_lookup(param_set, key) reply[user_expected_key] = cached_results[key] - params.remove(user_expected_key) + param_set.remove(user_expected_key) # if everything was cached then short circuit making the query - if not params: + if not param_set: if LOG_CACHE_FETCHES: log.trace('GETINFO %s (cache fetch)' % ' '.join(reply.keys())) @@ -1185,14 +1189,13 @@ def get_info(self, params, default = UNDEFINED, get_bytes = False): return list(reply.values())[0] try: - response = self.msg('GETINFO %s' % ' '.join(params)) - stem.response.convert('GETINFO', response) - response._assert_matches(params) + response = 
stem.response._convert_to_getinfo(self.msg('GETINFO %s' % ' '.join(param_set))) + response._assert_matches(param_set) # usually we want unicode values under python 3.x if not get_bytes: - response.entries = dict((k, stem.util.str_tools._to_unicode(v)) for (k, v) in response.entries.items()) + response.entries = dict((k, stem.util.str_tools._to_unicode(v)) for (k, v) in response.entries.items()) # type: ignore reply.update(response.entries) @@ -1209,30 +1212,30 @@ def get_info(self, params, default = UNDEFINED, get_bytes = False): self._set_cache(to_cache, 'getinfo') - if 'address' in params: + if 'address' in param_set: self._last_address_exc = None - if 'fingerprint' in params: + if 'fingerprint' in param_set: self._last_fingerprint_exc = None - log.debug('GETINFO %s (runtime: %0.4f)' % (' '.join(params), time.time() - start_time)) + log.debug('GETINFO %s (runtime: %0.4f)' % (' '.join(param_set), time.time() - start_time)) if is_multiple: return reply else: return list(reply.values())[0] except stem.ControllerError as exc: - if 'address' in params: + if 'address' in param_set: self._last_address_exc = exc - if 'fingerprint' in params: + if 'fingerprint' in param_set: self._last_fingerprint_exc = exc - log.debug('GETINFO %s (failed: %s)' % (' '.join(params), exc)) + log.debug('GETINFO %s (failed: %s)' % (' '.join(param_set), exc)) raise @with_default() - def get_version(self, default = UNDEFINED): + def get_version(self, default: Any = UNDEFINED) -> stem.version.Version: """ get_version(default = UNDEFINED) @@ -1261,7 +1264,7 @@ def get_version(self, default = UNDEFINED): return version @with_default() - def get_exit_policy(self, default = UNDEFINED): + def get_exit_policy(self, default: Any = UNDEFINED) -> stem.exit_policy.ExitPolicy: """ get_exit_policy(default = UNDEFINED) @@ -1293,7 +1296,7 @@ def get_exit_policy(self, default = UNDEFINED): return policy @with_default() - def get_ports(self, listener_type, default = UNDEFINED): + def get_ports(self, 
listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[int]: """ get_ports(listener_type, default = UNDEFINED) @@ -1315,7 +1318,7 @@ def get_ports(self, listener_type, default = UNDEFINED): and no default was provided """ - def is_localhost(address): + def is_localhost(address: str) -> bool: if stem.util.connection.is_valid_ipv4_address(address): return address == '0.0.0.0' or address.startswith('127.') elif stem.util.connection.is_valid_ipv6_address(address): @@ -1330,7 +1333,7 @@ def is_localhost(address): return [port for (addr, port) in self.get_listeners(listener_type) if is_localhost(addr)] @with_default() - def get_listeners(self, listener_type, default = UNDEFINED): + def get_listeners(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[Tuple[str, int]]: """ get_listeners(listener_type, default = UNDEFINED) @@ -1359,7 +1362,7 @@ def get_listeners(self, listener_type, default = UNDEFINED): if listeners is None: proxy_addrs = [] - query = 'net/listeners/%s' % listener_type.lower() + query = 'net/listeners/%s' % str(listener_type).lower() try: for listener in self.get_info(query).split(): @@ -1409,7 +1412,7 @@ def get_listeners(self, listener_type, default = UNDEFINED): Listener.CONTROL: 'ControlListenAddress', }[listener_type] - port_value = self.get_conf(port_option).split()[0] + port_value = self._get_conf_single(port_option).split()[0] for listener in self.get_conf(listener_option, multiple = True): if ':' in listener: @@ -1436,7 +1439,7 @@ def get_listeners(self, listener_type, default = UNDEFINED): return listeners @with_default() - def get_accounting_stats(self, default = UNDEFINED): + def get_accounting_stats(self, default: Any = UNDEFINED) -> 'stem.control.AccountingStats': """ get_accounting_stats(default = UNDEFINED) @@ -1480,7 +1483,7 @@ def get_accounting_stats(self, default = UNDEFINED): ) @with_default() - def get_protocolinfo(self, default = UNDEFINED): + def get_protocolinfo(self, default: 
Any = UNDEFINED) -> stem.response.protocolinfo.ProtocolInfoResponse: """ get_protocolinfo(default = UNDEFINED) @@ -1503,7 +1506,7 @@ def get_protocolinfo(self, default = UNDEFINED): return stem.connection.get_protocolinfo(self) @with_default() - def get_user(self, default = UNDEFINED): + def get_user(self, default: Any = UNDEFINED) -> str: """ get_user(default = UNDEFINED) @@ -1538,7 +1541,7 @@ def get_user(self, default = UNDEFINED): raise ValueError("Unable to resolve tor's user" if self.is_localhost() else "Tor isn't running locally") @with_default() - def get_pid(self, default = UNDEFINED): + def get_pid(self, default: Any = UNDEFINED) -> int: """ get_pid(default = UNDEFINED) @@ -1567,7 +1570,7 @@ def get_pid(self, default = UNDEFINED): pid = int(getinfo_pid) if not pid and self.is_localhost(): - pid_file_path = self.get_conf('PidFile', None) + pid_file_path = self._get_conf_single('PidFile', None) if pid_file_path is not None: with open(pid_file_path) as pid_file: @@ -1594,7 +1597,7 @@ def get_pid(self, default = UNDEFINED): raise ValueError("Unable to resolve tor's pid" if self.is_localhost() else "Tor isn't running locally") @with_default() - def get_start_time(self, default = UNDEFINED): + def get_start_time(self, default: Any = UNDEFINED) -> float: """ get_start_time(default = UNDEFINED) @@ -1644,7 +1647,7 @@ def get_start_time(self, default = UNDEFINED): raise ValueError("Unable to resolve when tor began" if self.is_localhost() else "Tor isn't running locally") @with_default() - def get_uptime(self, default = UNDEFINED): + def get_uptime(self, default: Any = UNDEFINED) -> float: """ get_uptime(default = UNDEFINED) @@ -1662,7 +1665,7 @@ def get_uptime(self, default = UNDEFINED): return time.time() - self.get_start_time() - def is_user_traffic_allowed(self): + def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed': """ Checks if we're likely to service direct user traffic. This essentially boils down to... 
@@ -1683,7 +1686,7 @@ def is_user_traffic_allowed(self): .. versionadded:: 1.5.0 - :returns: :class:`~stem.cotroller.UserTrafficAllowed` with **inbound** and + :returns: :class:`~stem.control.UserTrafficAllowed` with **inbound** and **outbound** boolean attributes to indicate if we're likely servicing direct user traffic """ @@ -1704,7 +1707,7 @@ def is_user_traffic_allowed(self): return UserTrafficAllowed(inbound_allowed, outbound_allowed) @with_default() - def get_microdescriptor(self, relay = None, default = UNDEFINED): + def get_microdescriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.microdescriptor.Microdescriptor: """ get_microdescriptor(relay = None, default = UNDEFINED) @@ -1762,7 +1765,7 @@ def get_microdescriptor(self, relay = None, default = UNDEFINED): return stem.descriptor.microdescriptor.Microdescriptor(desc_content) @with_default(yields = True) - def get_microdescriptors(self, default = UNDEFINED): + def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: """ get_microdescriptors(default = UNDEFINED) @@ -1793,7 +1796,7 @@ def get_microdescriptors(self, default = UNDEFINED): yield desc @with_default() - def get_server_descriptor(self, relay = None, default = UNDEFINED): + def get_server_descriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor: """ get_server_descriptor(relay = None, default = UNDEFINED) @@ -1856,7 +1859,7 @@ def get_server_descriptor(self, relay = None, default = UNDEFINED): return stem.descriptor.server_descriptor.RelayDescriptor(desc_content) @with_default(yields = True) - def get_server_descriptors(self, default = UNDEFINED): + def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ get_server_descriptors(default = UNDEFINED) @@ -1889,10 +1892,10 @@ def get_server_descriptors(self, default = 
UNDEFINED): raise stem.DescriptorUnavailable('Descriptor information is unavailable, tor might still be downloading it') for desc in stem.descriptor.server_descriptor._parse_file(io.BytesIO(desc_content)): - yield desc + yield desc # type: ignore @with_default() - def get_network_status(self, relay = None, default = UNDEFINED): + def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3: """ get_network_status(relay = None, default = UNDEFINED) @@ -1951,7 +1954,7 @@ def get_network_status(self, relay = None, default = UNDEFINED): return stem.descriptor.router_status_entry.RouterStatusEntryV3(desc_content) @with_default(yields = True) - def get_network_statuses(self, default = UNDEFINED): + def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]: """ get_network_statuses(default = UNDEFINED) @@ -1985,10 +1988,10 @@ def get_network_statuses(self, default = UNDEFINED): ) for desc in desc_iterator: - yield desc + yield desc # type: ignore @with_default() - def get_hidden_service_descriptor(self, address, default = UNDEFINED, servers = None, await_result = True, timeout = None): + def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2: """ get_hidden_service_descriptor(address, default = UNDEFINED, servers = None, await_result = True) @@ -2031,15 +2034,19 @@ def get_hidden_service_descriptor(self, address, default = UNDEFINED, servers = if not stem.util.tor_tools.is_valid_hidden_service_address(address): raise ValueError("'%s.onion' isn't a valid hidden service address" % address) - hs_desc_queue, hs_desc_listener = queue.Queue(), None - hs_desc_content_queue, hs_desc_content_listener = queue.Queue(), None + hs_desc_queue = 
queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_listener = None + + hs_desc_content_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_content_listener = None + start_time = time.time() if await_result: - def hs_desc_listener(event): + def hs_desc_listener(event: stem.response.events.Event) -> None: hs_desc_queue.put(event) - def hs_desc_content_listener(event): + def hs_desc_content_listener(event: stem.response.events.Event) -> None: hs_desc_content_queue.put(event) self.add_event_listener(hs_desc_listener, EventType.HS_DESC) @@ -2051,8 +2058,7 @@ def hs_desc_content_listener(event): if servers: request += ' ' + ' '.join(['SERVER=%s' % s for s in servers]) - response = self.msg(request) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg(request)) if not response.is_ok(): raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code) @@ -2084,7 +2090,7 @@ def hs_desc_content_listener(event): if hs_desc_content_listener: self.remove_event_listener(hs_desc_content_listener) - def get_conf(self, param, default = UNDEFINED, multiple = False): + def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]: """ get_conf(param, default = UNDEFINED, multiple = False) @@ -2133,7 +2139,15 @@ def get_conf(self, param, default = UNDEFINED, multiple = False): entries = self.get_conf_map(param, default, multiple) return _case_insensitive_lookup(entries, param, default) - def get_conf_map(self, params, default = UNDEFINED, multiple = True): + # TODO: temporary aliases until we have better type support in our API + + def _get_conf_single(self, param: str, default: Any = UNDEFINED) -> str: + return self.get_conf(param, default) # type: ignore + + def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]: + return self.get_conf(param, default, multiple = True) # type: ignore + 
+ def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]: """ get_conf_map(params, default = UNDEFINED, multiple = True) @@ -2214,8 +2228,7 @@ def get_conf_map(self, params, default = UNDEFINED, multiple = True): return self._get_conf_dict_to_response(reply, default, multiple) try: - response = self.msg('GETCONF %s' % ' '.join(lookup_params)) - stem.response.convert('GETCONF', response) + response = stem.response._convert_to_getconf(self.msg('GETCONF %s' % ' '.join(lookup_params))) reply.update(response.entries) if self.is_caching_enabled(): @@ -2251,7 +2264,7 @@ def get_conf_map(self, params, default = UNDEFINED, multiple = True): else: raise - def _get_conf_dict_to_response(self, config_dict, default, multiple): + def _get_conf_dict_to_response(self, config_dict: Mapping[str, Sequence[str]], default: Any, multiple: bool) -> Dict[str, Union[str, Sequence[str]]]: """ Translates a dictionary of 'config key => [value1, value2...]' into the return value of :func:`~stem.control.Controller.get_conf_map`, taking into @@ -2273,7 +2286,7 @@ def _get_conf_dict_to_response(self, config_dict, default, multiple): return return_dict @with_default() - def is_set(self, param, default = UNDEFINED): + def is_set(self, param: str, default: Any = UNDEFINED) -> bool: """ is_set(param, default = UNDEFINED) @@ -2293,7 +2306,7 @@ def is_set(self, param, default = UNDEFINED): return param in self._get_custom_options() - def _get_custom_options(self): + def _get_custom_options(self) -> Dict[str, str]: result = self._get_cache('get_custom_options') if not result: @@ -2320,7 +2333,7 @@ def _get_custom_options(self): return result - def set_conf(self, param, value): + def set_conf(self, param: str, value: Union[str, Sequence[str]]) -> None: """ Changes the value of a tor configuration option. Our value can be any of the following... 
@@ -2342,7 +2355,7 @@ def set_conf(self, param, value): self.set_options({param: value}, False) - def reset_conf(self, *params): + def reset_conf(self, *params: str) -> None: """ Reverts one or more parameters to their default values. @@ -2357,7 +2370,7 @@ def reset_conf(self, *params): self.set_options(dict([(entry, None) for entry in params]), True) - def set_options(self, params, reset = False): + def set_options(self, params: Union[Mapping[str, Union[str, Sequence[str]]], Sequence[Tuple[str, Union[str, Sequence[str]]]]], reset: bool = False) -> None: """ Changes multiple tor configuration options via either a SETCONF or RESETCONF query. Both behave identically unless our value is None, in which @@ -2410,8 +2423,7 @@ def set_options(self, params, reset = False): raise ValueError('Cannot set %s to %s since the value was a %s but we only accept strings' % (param, value, type(value).__name__)) query = ' '.join(query_comp) - response = self.msg(query) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg(query)) if response.is_ok(): log.debug('%s (runtime: %0.4f)' % (query, time.time() - start_time)) @@ -2439,7 +2451,7 @@ def set_options(self, params, reset = False): raise stem.ProtocolError('Returned unexpected status code: %s' % response.code) @with_default() - def get_hidden_service_conf(self, default = UNDEFINED): + def get_hidden_service_conf(self, default: Any = UNDEFINED) -> Dict[str, Any]: """ get_hidden_service_conf(default = UNDEFINED) @@ -2485,15 +2497,14 @@ def get_hidden_service_conf(self, default = UNDEFINED): start_time = time.time() try: - response = self.msg('GETCONF HiddenServiceOptions') - stem.response.convert('GETCONF', response) + response = stem.response._convert_to_getconf(self.msg('GETCONF HiddenServiceOptions')) log.debug('GETCONF HiddenServiceOptions (runtime: %0.4f)' % (time.time() - start_time)) except stem.ControllerError as exc: log.debug('GETCONF HiddenServiceOptions (failed: %s)' 
% exc) raise - service_dir_map = collections.OrderedDict() + service_dir_map = collections.OrderedDict() # type: collections.OrderedDict[str, Any] directory = None for status_code, divider, content in response.content(): @@ -2534,7 +2545,7 @@ def get_hidden_service_conf(self, default = UNDEFINED): self._set_cache({'hidden_service_conf': service_dir_map}) return service_dir_map - def set_hidden_service_conf(self, conf): + def set_hidden_service_conf(self, conf: Mapping[str, Any]) -> None: """ Update all the configured hidden services from a dictionary having the same format as @@ -2599,7 +2610,7 @@ def set_hidden_service_conf(self, conf): self.set_options(hidden_service_options) - def create_hidden_service(self, path, port, target_address = None, target_port = None, auth_type = None, client_names = None): + def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput': """ Create a new hidden service. If the directory is already present, a new port is added. @@ -2625,7 +2636,7 @@ def create_hidden_service(self, path, port, target_address = None, target_port = :param str auth_type: authentication type: basic, stealth or None to disable auth :param list client_names: client names (1-16 characters "A-Za-z0-9+-_") - :returns: :class:`~stem.cotroller.CreateHiddenServiceOutput` if we create + :returns: :class:`~stem.control.CreateHiddenServiceOutput` if we create or update a hidden service, **None** otherwise :raises: :class:`stem.ControllerError` if the call fails @@ -2717,7 +2728,7 @@ def create_hidden_service(self, path, port, target_address = None, target_port = config = conf, ) - def remove_hidden_service(self, path, port = None): + def remove_hidden_service(self, path: str, port: Optional[int] = None) -> bool: """ Discontinues a given hidden service. 
@@ -2759,7 +2770,7 @@ def remove_hidden_service(self, path, port = None): return True @with_default() - def list_ephemeral_hidden_services(self, default = UNDEFINED, our_services = True, detached = False): + def list_ephemeral_hidden_services(self, default: Any = UNDEFINED, our_services: bool = True, detached: bool = False) -> Sequence[str]: """ list_ephemeral_hidden_services(default = UNDEFINED, our_services = True, detached = False) @@ -2799,7 +2810,7 @@ def list_ephemeral_hidden_services(self, default = UNDEFINED, our_services = Tru return [r for r in result if r] # drop any empty responses (GETINFO is blank if unset) - def create_ephemeral_hidden_service(self, ports, key_type = 'NEW', key_content = 'BEST', discard_key = False, detached = False, await_publication = False, timeout = None, basic_auth = None, max_streams = None): + def create_ephemeral_hidden_service(self, ports: Union[int, Sequence[int], Mapping[int, str]], key_type: str = 'NEW', key_content: str = 'BEST', discard_key: bool = False, detached: bool = False, await_publication: bool = False, timeout: Optional[float] = None, basic_auth: Optional[Mapping[str, str]] = None, max_streams: Optional[int] = None) -> stem.response.add_onion.AddOnionResponse: """ Creates a new hidden service. 
Unlike :func:`~stem.control.Controller.create_hidden_service` this style of @@ -2901,11 +2912,12 @@ def create_ephemeral_hidden_service(self, ports, key_type = 'NEW', key_content = * :class:`stem.Timeout` if **timeout** was reached """ - hs_desc_queue, hs_desc_listener = queue.Queue(), None + hs_desc_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_listener = None start_time = time.time() if await_publication: - def hs_desc_listener(event): + def hs_desc_listener(event: stem.response.events.Event) -> None: hs_desc_queue.put(event) self.add_event_listener(hs_desc_listener, EventType.HS_DESC) @@ -2953,8 +2965,7 @@ def hs_desc_listener(event): else: request += ' ClientAuth=%s' % client_name - response = self.msg(request) - stem.response.convert('ADD_ONION', response) + response = stem.response._convert_to_add_onion(self.msg(request)) if await_publication: # We should receive five UPLOAD events, followed by up to another five 
@@ -2998,8 +3009,7 @@ def remove_ephemeral_hidden_service(self, service_id): :raises: :class:`stem.ControllerError` if the call fails """ - response = self.msg('DEL_ONION %s' % service_id) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('DEL_ONION %s' % service_id)) if response.is_ok(): return True @@ -3008,7 +3018,7 @@ def remove_ephemeral_hidden_service(self, service_id): else: raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code) - def add_event_listener(self, listener, *events): + def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None: """ Directs further tor controller events to a given function. The function is expected to take a single argument, which is a @@ -3052,7 +3062,7 @@ def print_bw(event): event_type = stem.response.events.EVENT_TYPE_TO_CLASS.get(event_type) if event_type and (self.get_version() < event_type._VERSION_ADDED): - raise stem.InvalidRequest(552, '%s event requires Tor version %s or later' % (event_type, event_type._VERSION_ADDED)) + raise stem.InvalidRequest('552', '%s event requires Tor version %s or later' % (event_type, event_type._VERSION_ADDED)) for event_type in events: self._event_listeners.setdefault(event_type, []).append(listener) @@ -3066,7 +3076,7 @@ def print_bw(event): if failed_events: raise stem.ProtocolError('SETEVENTS rejected %s' % ', '.join(failed_events)) - def remove_event_listener(self, listener): + def remove_event_listener(self, listener: Callable[[stem.response.events.Event], None]) -> None: """ Stops a listener from being notified of further tor events. 
@@ -3092,7 +3102,7 @@ def remove_event_listener(self, listener): if not response.is_ok(): raise stem.ProtocolError('SETEVENTS received unexpected response\n%s' % response) - def _get_cache(self, param, namespace = None): + def _get_cache(self, param: str, namespace: Optional[str] = None) -> Any: """ Queries our request cache for the given key. @@ -3109,7 +3119,7 @@ def _get_cache(self, param, namespace = None): cache_key = '%s.%s' % (namespace, param) if namespace else param return self._request_cache.get(cache_key, None) - def _get_cache_map(self, params, namespace = None): + def _get_cache_map(self, params: Sequence[str], namespace: Optional[str] = None) -> Dict[str, Any]: """ Queries our request cache for multiple entries. @@ -3131,7 +3141,7 @@ def _get_cache_map(self, params, namespace = None): return cached_values - def _set_cache(self, params, namespace = None): + def _set_cache(self, params: Dict[str, Any], namespace: Optional[str] = None) -> None: """ Sets the given request cache entries. If the new cache value is **None** then it is removed from our cache. @@ -3173,7 +3183,7 @@ def _set_cache(self, params, namespace = None): else: self._request_cache[cache_key] = value - def _confchanged_cache_invalidation(self, params): + def _confchanged_cache_invalidation(self, params: Mapping[str, Any]) -> None: """ Drops dependent portions of the cache when configuration changes. @@ -3197,7 +3207,7 @@ def _confchanged_cache_invalidation(self, params): self._set_cache({'exit_policy': None}) # numerous options can change our policy - def is_caching_enabled(self): + def is_caching_enabled(self) -> bool: """ **True** if caching has been enabled, **False** otherwise. @@ -3206,7 +3216,7 @@ def is_caching_enabled(self): return self._is_caching_enabled - def set_caching(self, enabled): + def set_caching(self, enabled: bool) -> None: """ Enables or disables caching of information retrieved from tor. 
@@ -3218,7 +3228,7 @@ def set_caching(self, enabled): if not self._is_caching_enabled: self.clear_cache() - def clear_cache(self): + def clear_cache(self) -> None: """ Drops any cached results. """ @@ -3227,7 +3237,7 @@ def clear_cache(self): self._request_cache = {} self._last_newnym = 0.0 - def load_conf(self, configtext): + def load_conf(self, configtext: str) -> None: """ Sends the configuration text to Tor and loads it as if it has been read from the torrc. @@ -3237,8 +3247,7 @@ def load_conf(self, configtext): :raises: :class:`stem.ControllerError` if the call fails """ - response = self.msg('LOADCONF\n%s' % configtext) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('LOADCONF\n%s' % configtext)) if response.code in ('552', '553'): if response.code == '552' and response.message.startswith('Invalid config file: Failed to parse/validate config: Unknown option'): @@ -3247,7 +3256,7 @@ def load_conf(self, configtext): elif not response.is_ok(): raise stem.ProtocolError('+LOADCONF Received unexpected response\n%s' % str(response)) - def save_conf(self, force = False): + def save_conf(self, force: bool = False) -> None: """ Saves the current configuration options into the active torrc file. @@ -3263,17 +3272,16 @@ def save_conf(self, force = False): the configuration file """ - response = self.msg('SAVECONF FORCE' if force else 'SAVECONF') - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('SAVECONF FORCE' if force else 'SAVECONF')) if response.is_ok(): - return True + pass elif response.code == '551': raise stem.OperationFailed(response.code, response.message) else: raise stem.ProtocolError('SAVECONF returned unexpected response code') - def is_feature_enabled(self, feature): + def is_feature_enabled(self, feature: str) -> bool: """ Checks if a control connection feature is enabled. 
These features can be enabled using :func:`~stem.control.Controller.enable_feature`. @@ -3290,7 +3298,7 @@ def is_feature_enabled(self, feature): return feature in self._enabled_features - def enable_feature(self, features): + def enable_feature(self, features: Union[str, Sequence[str]]) -> None: """ Enables features that are disabled by default to maintain backward compatibility. Once enabled, a feature cannot be disabled and a new @@ -3307,8 +3315,7 @@ def enable_feature(self, features): if isinstance(features, (bytes, str)): features = [features] - response = self.msg('USEFEATURE %s' % ' '.join(features)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('USEFEATURE %s' % ' '.join(features))) if not response.is_ok(): if response.code == '552': @@ -3324,7 +3331,7 @@ def enable_feature(self, features): self._enabled_features += [entry.upper() for entry in features] @with_default() - def get_circuit(self, circuit_id, default = UNDEFINED): + def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent: """ get_circuit(circuit_id, default = UNDEFINED) @@ -3349,7 +3356,7 @@ def get_circuit(self, circuit_id, default = UNDEFINED): raise ValueError("Tor currently does not have a circuit with the id of '%s'" % circuit_id) @with_default() - def get_circuits(self, default = UNDEFINED): + def get_circuits(self, default: Any = UNDEFINED) -> List[stem.response.events.CircuitEvent]: """ get_circuits(default = UNDEFINED) @@ -3362,17 +3369,16 @@ def get_circuits(self, default = UNDEFINED): :raises: :class:`stem.ControllerError` if the call fails and no default was provided """ - circuits = [] + circuits = [] # type: List[stem.response.events.CircuitEvent] response = self.get_info('circuit-status') for circ in response.splitlines(): - circ_message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))) - stem.response.convert('EVENT', 
circ_message) - circuits.append(circ_message) + circ_message = stem.response._convert_to_event(stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ)))) + circuits.append(circ_message) # type: ignore return circuits - def new_circuit(self, path = None, purpose = 'general', await_build = False, timeout = None): + def new_circuit(self, path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str: """ Requests a new circuit. If the path isn't provided, one is automatically selected. @@ -3380,7 +3386,7 @@ def new_circuit(self, path = None, purpose = 'general', await_build = False, tim .. versionchanged:: 1.7.0 Added the timeout argument. - :param list,str path: one or more relays to make a circuit through + :param str,list path: one or more relays to make a circuit through :param str purpose: 'general' or 'controller' :param bool await_build: blocks until the circuit is built if **True** :param float timeout: seconds to wait when **await_build** is **True** @@ -3394,7 +3400,7 @@ def new_circuit(self, path = None, purpose = 'general', await_build = False, tim return self.extend_circuit('0', path, purpose, await_build, timeout) - def extend_circuit(self, circuit_id = '0', path = None, purpose = 'general', await_build = False, timeout = None): + def extend_circuit(self, circuit_id: str = '0', path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str: """ Either requests the creation of a new circuit or extends an existing one. @@ -3418,7 +3424,7 @@ def extend_circuit(self, circuit_id = '0', path = None, purpose = 'general', awa Added the timeout argument. 
:param str circuit_id: id of a circuit to be extended - :param list,str path: one or more relays to make a circuit through, this is + :param str,list path: one or more relays to make a circuit through, this is required if the circuit id is non-zero :param str purpose: 'general' or 'controller' :param bool await_build: blocks until the circuit is built if **True** @@ -3438,11 +3444,12 @@ def extend_circuit(self, circuit_id = '0', path = None, purpose = 'general', awa # to build. This is icky, but we can't reliably do this via polling since # we then can't get the failure if it can't be created. - circ_queue, circ_listener = queue.Queue(), None + circ_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + circ_listener = None start_time = time.time() if await_build: - def circ_listener(event): + def circ_listener(event: stem.response.events.Event) -> None: circ_queue.put(event) self.add_event_listener(circ_listener, EventType.CIRC) @@ -3459,8 +3466,7 @@ def circ_listener(event): if purpose: args.append('purpose=%s' % purpose) - response = self.msg('EXTENDCIRCUIT %s' % ' '.join(args)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('EXTENDCIRCUIT %s' % ' '.join(args))) if response.code in ('512', '552'): raise stem.InvalidRequest(response.code, response.message) @@ -3489,7 +3495,7 @@ def circ_listener(event): if circ_listener: self.remove_event_listener(circ_listener) - def repurpose_circuit(self, circuit_id, purpose): + def repurpose_circuit(self, circuit_id: str, purpose: str) -> None: """ Changes a circuit's purpose. Currently, two purposes are recognized... 
* general @@ -3501,8 +3507,7 @@ def repurpose_circuit(self, circuit_id, purpose): :raises: :class:`stem.InvalidArguments` if the circuit doesn't exist or if the purpose was invalid """ - response = self.msg('SETCIRCUITPURPOSE %s purpose=%s' % (circuit_id, purpose)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('SETCIRCUITPURPOSE %s purpose=%s' % (circuit_id, purpose))) if not response.is_ok(): if response.code == '552': @@ -3510,7 +3515,7 @@ def repurpose_circuit(self, circuit_id, purpose): else: raise stem.ProtocolError('SETCIRCUITPURPOSE returned unexpected response code: %s' % response.code) - def close_circuit(self, circuit_id, flag = ''): + def close_circuit(self, circuit_id: str, flag: str = '') -> None: """ Closes the specified circuit. @@ -3518,12 +3523,12 @@ def close_circuit(self, circuit_id, flag = ''): :param str flag: optional value to modify closing, the only flag available is 'IfUnused' which will not close the circuit unless it is unused - :raises: :class:`stem.InvalidArguments` if the circuit is unknown - :raises: :class:`stem.InvalidRequest` if not enough information is provided + :raises: + * :class:`stem.InvalidArguments` if the circuit is unknown + * :class:`stem.InvalidRequest` if not enough information is provided """ - response = self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag))) if not response.is_ok(): if response.code in ('512', '552'): @@ -3534,7 +3539,7 @@ def close_circuit(self, circuit_id, flag = ''): raise stem.ProtocolError('CLOSECIRCUIT returned unexpected response code: %s' % response.code) @with_default() - def get_streams(self, default = UNDEFINED): + def get_streams(self, default: Any = UNDEFINED) -> List[stem.response.events.StreamEvent]: """ get_streams(default = UNDEFINED) @@ -3548,17 +3553,16 @@ def 
get_streams(self, default = UNDEFINED): provided """ - streams = [] + streams = [] # type: List[stem.response.events.StreamEvent] response = self.get_info('stream-status') for stream in response.splitlines(): - message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))) - stem.response.convert('EVENT', message) - streams.append(message) + message = stem.response._convert_to_event(stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream)))) + streams.append(message) # type: ignore return streams - def attach_stream(self, stream_id, circuit_id, exiting_hop = None): + def attach_stream(self, stream_id: str, circuit_id: str, exiting_hop: Optional[int] = None) -> None: """ Attaches a stream to a circuit. @@ -3580,8 +3584,7 @@ def attach_stream(self, stream_id, circuit_id, exiting_hop = None): if exiting_hop: query += ' HOP=%s' % exiting_hop - response = self.msg(query) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg(query)) if not response.is_ok(): if response.code == '552': @@ -3593,7 +3596,7 @@ def attach_stream(self, stream_id, circuit_id, exiting_hop = None): else: raise stem.ProtocolError('ATTACHSTREAM returned unexpected response code: %s' % response.code) - def close_stream(self, stream_id, reason = stem.RelayEndReason.MISC, flag = ''): + def close_stream(self, stream_id: str, reason: stem.RelayEndReason = stem.RelayEndReason.MISC, flag: str = '') -> None: """ Closes the specified stream. 
@@ -3609,8 +3612,7 @@ def close_stream(self, stream_id, reason = stem.RelayEndReason.MISC, flag = ''): # there's a single value offset between RelayEndReason.index_of() and the # value that tor expects since tor's value starts with the index of one - response = self.msg('CLOSESTREAM %s %s %s' % (stream_id, stem.RelayEndReason.index_of(reason) + 1, flag)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('CLOSESTREAM %s %s %s' % (stream_id, stem.RelayEndReason.index_of(reason) + 1, flag))) if not response.is_ok(): if response.code in ('512', '552'): @@ -3622,7 +3624,7 @@ def close_stream(self, stream_id, reason = stem.RelayEndReason.MISC, flag = ''): else: raise stem.ProtocolError('CLOSESTREAM returned unexpected response code: %s' % response.code) - def signal(self, signal): + def signal(self, signal: stem.Signal) -> None: """ Sends a signal to the Tor client. @@ -3633,8 +3635,7 @@ def signal(self, signal): * :class:`stem.InvalidArguments` if signal provided wasn't recognized """ - response = self.msg('SIGNAL %s' % signal) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('SIGNAL %s' % signal)) if response.is_ok(): if signal == stem.Signal.NEWNYM: @@ -3645,7 +3646,7 @@ def signal(self, signal): raise stem.ProtocolError('SIGNAL response contained unrecognized status code: %s' % response.code) - def is_newnym_available(self): + def is_newnym_available(self) -> bool: """ Indicates if tor would currently accept a NEWNYM signal. This can only account for signals sent via this controller. @@ -3661,7 +3662,7 @@ def is_newnym_available(self): else: return False - def get_newnym_wait(self): + def get_newnym_wait(self) -> float: """ Provides the number of seconds until a NEWNYM signal would be respected. This can only account for signals sent via this controller. 
@@ -3675,7 +3676,7 @@ def get_newnym_wait(self): return max(0.0, self._last_newnym + 10 - time.time()) @with_default() - def get_effective_rate(self, default = UNDEFINED, burst = False): + def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int: """ get_effective_rate(default = UNDEFINED, burst = False) @@ -3698,14 +3699,14 @@ def get_effective_rate(self, default = UNDEFINED, burst = False): """ if not burst: - attributes = ('BandwidthRate', 'RelayBandwidthRate', 'MaxAdvertisedBandwidth') + attributes = ['BandwidthRate', 'RelayBandwidthRate', 'MaxAdvertisedBandwidth'] else: - attributes = ('BandwidthBurst', 'RelayBandwidthBurst') + attributes = ['BandwidthBurst', 'RelayBandwidthBurst'] value = None for attr in attributes: - attr_value = int(self.get_conf(attr)) + attr_value = int(self._get_conf_single(attr)) if attr_value == 0 and attr.startswith('Relay'): continue # RelayBandwidthRate and RelayBandwidthBurst default to zero @@ -3714,7 +3715,7 @@ def get_effective_rate(self, default = UNDEFINED, burst = False): return value - def map_address(self, mapping): + def map_address(self, mapping: Mapping[str, str]) -> Dict[str, str]: """ Map addresses to replacement addresses. Tor replaces subseqent connections to the original addresses with the replacement addresses. 
@@ -3726,20 +3727,18 @@ def map_address(self, mapping): :param dict mapping: mapping of original addresses to replacement addresses + :returns: **dict** with 'original -> replacement' address mappings + :raises: * :class:`stem.InvalidRequest` if the addresses are malformed * :class:`stem.OperationFailed` if Tor couldn't fulfill the request - - :returns: **dict** with 'original -> replacement' address mappings """ mapaddress_arg = ' '.join(['%s=%s' % (k, v) for (k, v) in list(mapping.items())]) response = self.msg('MAPADDRESS %s' % mapaddress_arg) - stem.response.convert('MAPADDRESS', response) + return stem.response._convert_to_mapaddress(response).entries - return response.entries - - def drop_guards(self): + def drop_guards(self) -> None: """ Drops our present guard nodes and picks a new set. @@ -3750,7 +3749,7 @@ def drop_guards(self): self.msg('DROPGUARDS') - def _post_authentication(self): + def _post_authentication(self) -> None: super(Controller, self)._post_authentication() # try to re-attach event listeners to the new instance @@ -3774,8 +3773,7 @@ def _post_authentication(self): owning_pid = self.get_conf('__OwningControllerProcess', None) if owning_pid == str(os.getpid()) and self.is_localhost(): - response = self.msg('TAKEOWNERSHIP') - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('TAKEOWNERSHIP')) if response.is_ok(): # Now that tor is tracking our ownership of the process via the control @@ -3788,11 +3786,18 @@ def _post_authentication(self): else: log.warn('We were unable assert ownership of tor through TAKEOWNERSHIP, despite being configured to be the owning process through __OwningControllerProcess. 
(%s)' % response) - def _handle_event(self, event_message): + def _handle_event(self, event_message: stem.response.ControlMessage) -> None: + event = None # type: Optional[stem.response.events.Event] + try: - stem.response.convert('EVENT', event_message) - event_type = event_message.type + event = stem.response._convert_to_event(event_message) + event_type = event.type except stem.ProtocolError as exc: + # TODO: We should change this so malformed events convert to the base + # Event class, so we don't provide raw ControlMessages to listeners. + + event = event_message # type: ignore + log.error('Tor sent a malformed event (%s): %s' % (exc, event_message)) event_type = MALFORMED_EVENTS @@ -3801,11 +3806,11 @@ def _handle_event(self, event_message): if listener_type == event_type: for listener in event_listeners: try: - listener(event_message) + listener(event) except Exception as exc: - log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event_message)) + log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event)) - def _attach_listeners(self): + def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]: """ Attempts to subscribe to the self._event_listeners events from tor. This is a no-op if we're not currently authenticated. @@ -3849,7 +3854,7 @@ def _attach_listeners(self): return (set_events, failed_events) -def _parse_circ_path(path): +def _parse_circ_path(path: str) -> Sequence[Tuple[str, str]]: """ Parses a circuit path as a list of **(fingerprint, nickname)** tuples. Tor circuit paths are defined as being of the form... @@ -3892,7 +3897,7 @@ def _parse_circ_path(path): return [] -def _parse_circ_entry(entry): +def _parse_circ_entry(entry: str) -> Tuple[str, str]: """ Parses a single relay's 'LongName' or 'ServerID'. See the :func:`~stem.control._parse_circ_path` function for more information. 
@@ -3930,7 +3935,7 @@ def _parse_circ_entry(entry): @with_default() -def _case_insensitive_lookup(entries, key, default = UNDEFINED): +def _case_insensitive_lookup(entries: Union[Sequence[str], Mapping[str, Any]], key: str, default: Any = UNDEFINED) -> Any: """ Makes a case insensitive lookup within a list or dictionary, providing the first matching entry that we come across. @@ -3957,7 +3962,7 @@ def _case_insensitive_lookup(entries, key, default = UNDEFINED): raise ValueError("key '%s' doesn't exist in dict: %s" % (key, entries)) -def _get_with_timeout(event_queue, timeout, start_time): +def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time: float) -> Any: """ Pulls an item from a queue with a given timeout. """ diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py index ff2734057..477e15e9b 100644 --- a/stem/descriptor/__init__.py +++ b/stem/descriptor/__init__.py @@ -108,6 +108,7 @@ import codecs import collections import copy +import hashlib import io import os import random @@ -120,6 +121,8 @@ import stem.util.str_tools import stem.util.system +from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union + __all__ = [ 'bandwidth_file', 'certificate', @@ -150,7 +153,7 @@ SPECIFIC_KEYWORD_LINE = '^(%%s)(?:[%s]+(.*))?$' % WHITESPACE PGP_BLOCK_START = re.compile('^-----BEGIN ([%s%s]+)-----$' % (KEYWORD_CHAR, WHITESPACE)) PGP_BLOCK_END = '-----END %s-----' -EMPTY_COLLECTION = ([], {}, set()) +EMPTY_COLLECTION = ([], {}, set()) # type: ignore DIGEST_TYPE_INFO = b'\x00\x01' DIGEST_PADDING = b'\xFF' @@ -162,6 +165,8 @@ WPi4Fl2qryzTb3QO5r5x7T8OsG2IBUET1bLQzmtbC560SYR49IvVAgMBAAE= """ +ENTRY_TYPE = Dict[str, List[Tuple[str, str, str]]] + DigestHash = stem.util.enum.UppercaseEnum( 'SHA1', 'SHA256', @@ -192,7 +197,7 @@ class _Compression(object): .. 
versionadded:: 1.8.0 """ - def __init__(self, name, module, encoding, extension, decompression_func): + def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, bytes], bytes]) -> None: if module is None: self._module = None self.available = True @@ -222,7 +227,7 @@ def __init__(self, name, module, encoding, extension, decompression_func): self._module_name = module self._decompression_func = decompression_func - def decompress(self, content): + def decompress(self, content: bytes) -> bytes: """ Decompresses the given content via this method. @@ -250,11 +255,11 @@ def decompress(self, content): except Exception as exc: raise IOError('Failed to decompress as %s: %s' % (self, exc)) - def __str__(self): + def __str__(self) -> str: return self._name -def _zstd_decompress(module, content): +def _zstd_decompress(module: Any, content: bytes) -> bytes: output_buffer = io.BytesIO() with module.ZstdDecompressor().write_to(output_buffer) as decompressor: @@ -286,7 +291,7 @@ class TypeAnnotation(collections.namedtuple('TypeAnnotation', ['name', 'major_ve :var int minor_version: minor version number """ - def __str__(self): + def __str__(self) -> str: return '@type %s %s.%s' % (self.name, self.major_version, self.minor_version) @@ -302,7 +307,7 @@ class SigningKey(collections.namedtuple('SigningKey', ['private', 'public', 'pub """ -def parse_file(descriptor_file, descriptor_type = None, validate = False, document_handler = DocumentHandler.ENTRIES, normalize_newlines = None, **kwargs): +def parse_file(descriptor_file: Union[str, BinaryIO, tarfile.TarFile, IO[bytes]], descriptor_type: Optional[str] = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: """ Simple function to read the descriptor contents from a file, providing an iterator for its 
:class:`~stem.descriptor.__init__.Descriptor` contents. @@ -370,7 +375,7 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume # Delegate to a helper if this is a path or tarfile. - handler = None + handler = None # type: Callable if isinstance(descriptor_file, (bytes, str)): if stem.util.system.is_tarfile(descriptor_file): @@ -386,7 +391,7 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume return - if not descriptor_file.seekable(): + if not descriptor_file.seekable(): # type: ignore raise IOError(UNSEEKABLE_MSG) # The tor descriptor specifications do not provide a reliable method for @@ -395,19 +400,19 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume # by an annotation on their first line... # https://trac.torproject.org/5651 - initial_position = descriptor_file.tell() - first_line = stem.util.str_tools._to_unicode(descriptor_file.readline().strip()) + initial_position = descriptor_file.tell() # type: ignore + first_line = stem.util.str_tools._to_unicode(descriptor_file.readline().strip()) # type: ignore metrics_header_match = re.match('^@type (\\S+) (\\d+).(\\d+)$', first_line) if not metrics_header_match: - descriptor_file.seek(initial_position) + descriptor_file.seek(initial_position) # type: ignore descriptor_path = getattr(descriptor_file, 'name', None) - filename = '' if descriptor_path is None else os.path.basename(descriptor_file.name) + filename = '' if descriptor_path is None else os.path.basename(descriptor_file.name) # type: str # type: ignore - def parse(descriptor_file): + def parse(descriptor_file: BinaryIO) -> Iterator['stem.descriptor.Descriptor']: if normalize_newlines: - descriptor_file = NewlineNormalizer(descriptor_file) + descriptor_file = NewlineNormalizer(descriptor_file) # type: ignore if descriptor_type is not None: descriptor_type_match = re.match('^(\\S+) (\\d+).(\\d+)$', descriptor_type) @@ -426,7 +431,7 @@ def parse(descriptor_file): # 
Cached descriptor handling. These contain multiple descriptors per file. if normalize_newlines is None and stem.util.system.is_windows(): - descriptor_file = NewlineNormalizer(descriptor_file) + descriptor_file = NewlineNormalizer(descriptor_file) # type: ignore if filename == 'cached-descriptors' or filename == 'cached-descriptors.new': return stem.descriptor.server_descriptor._parse_file(descriptor_file, validate = validate, **kwargs) @@ -439,29 +444,29 @@ def parse(descriptor_file): elif filename == 'cached-microdesc-consensus': return stem.descriptor.networkstatus._parse_file(descriptor_file, is_microdescriptor = True, validate = validate, document_handler = document_handler, **kwargs) else: - raise TypeError("Unable to determine the descriptor's type. filename: '%s', first line: '%s'" % (filename, first_line)) + raise TypeError("Unable to determine the descriptor's type. filename: '%s', first line: '%s'" % (filename, stem.util.str_tools._to_unicode(first_line))) - for desc in parse(descriptor_file): + for desc in parse(descriptor_file): # type: ignore if descriptor_path is not None: desc._set_path(os.path.abspath(descriptor_path)) yield desc -def _parse_file_for_path(descriptor_file, *args, **kwargs): +def _parse_file_for_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: with open(descriptor_file, 'rb') as desc_file: for desc in parse_file(desc_file, *args, **kwargs): yield desc -def _parse_file_for_tar_path(descriptor_file, *args, **kwargs): +def _parse_file_for_tar_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: with tarfile.open(descriptor_file) as tar_file: for desc in parse_file(tar_file, *args, **kwargs): desc._set_path(os.path.abspath(descriptor_file)) yield desc -def _parse_file_for_tarfile(descriptor_file, *args, **kwargs): +def _parse_file_for_tarfile(descriptor_file: tarfile.TarFile, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: 
for tar_entry in descriptor_file: if tar_entry.isfile(): entry = descriptor_file.extractfile(tar_entry) @@ -477,10 +482,14 @@ def _parse_file_for_tarfile(descriptor_file, *args, **kwargs): entry.close() -def _parse_metrics_file(descriptor_type, major_version, minor_version, descriptor_file, validate, document_handler, **kwargs): +def _parse_metrics_file(descriptor_type: str, major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: # Parses descriptor files from metrics, yielding individual descriptors. This # throws a TypeError if the descriptor_type or version isn't recognized. + desc = None # type: Optional[Any] + desc_type = None # type: Optional[Type[stem.descriptor.Descriptor]] + document_type = None # type: Optional[Type] + if descriptor_type == stem.descriptor.server_descriptor.RelayDescriptor.TYPE_ANNOTATION_NAME and major_version == 1: for desc in stem.descriptor.server_descriptor._parse_file(descriptor_file, is_bridge = False, validate = validate, **kwargs): yield desc @@ -505,7 +514,7 @@ def _parse_metrics_file(descriptor_type, major_version, minor_version, descripto for desc in stem.descriptor.networkstatus._parse_file(descriptor_file, document_type, validate = validate, document_handler = document_handler, **kwargs): yield desc elif descriptor_type == stem.descriptor.networkstatus.KeyCertificate.TYPE_ANNOTATION_NAME and major_version == 1: - for desc in stem.descriptor.networkstatus._parse_file_key_certs(descriptor_file, validate = validate, **kwargs): + for desc in stem.descriptor.networkstatus._parse_file_key_certs(descriptor_file, validate = validate): yield desc elif descriptor_type in ('network-status-consensus-3', 'network-status-vote-3') and major_version == 1: document_type = stem.descriptor.networkstatus.NetworkStatusDocumentV3 @@ -547,7 +556,7 @@ def _parse_metrics_file(descriptor_type, major_version, 
minor_version, descripto raise TypeError("Unrecognized metrics descriptor format. type: '%s', version: '%i.%i'" % (descriptor_type, major_version, minor_version)) -def _descriptor_content(attr = None, exclude = (), header_template = (), footer_template = ()): +def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = (), header_template: Sequence[Tuple[str, Optional[str]]] = (), footer_template: Sequence[Tuple[str, Optional[str]]] = ()) -> bytes: """ Constructs a minimal descriptor with the given attributes. The content we provide back is of the form... @@ -584,8 +593,9 @@ def _descriptor_content(attr = None, exclude = (), header_template = (), footer_ :returns: bytes with the requested descriptor content """ - header_content, footer_content = [], [] - attr = {} if attr is None else collections.OrderedDict(attr) # shallow copy since we're destructive + header_content = [] # type: List[str] + footer_content = [] # type: List[str] + attr = {} if attr is None else collections.OrderedDict(attr) # type: Dict[str, str] # shallow copy since we're destructive for content, template in ((header_content, header_template), (footer_content, footer_template)): @@ -619,28 +629,28 @@ def _descriptor_content(attr = None, exclude = (), header_template = (), footer_ return stem.util.str_tools._to_bytes('\n'.join(header_content + remainder + footer_content)) -def _value(line, entries): +def _value(line: str, entries: ENTRY_TYPE) -> str: return entries[line][0][0] -def _values(line, entries): +def _values(line: str, entries: ENTRY_TYPE) -> Sequence[str]: return [entry[0] for entry in entries[line]] -def _parse_simple_line(keyword, attribute, func = None): - def _parse(descriptor, entries): +def _parse_simple_line(keyword: str, attribute: str, func: Optional[Callable[[str], Any]] = None) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, 
entries) setattr(descriptor, attribute, func(value) if func else value) return _parse -def _parse_if_present(keyword, attribute): +def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: return lambda descriptor, entries: setattr(descriptor, attribute, keyword in entries) -def _parse_bytes_line(keyword, attribute): - def _parse(descriptor, entries): +def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: line_match = re.search(stem.util.str_tools._to_bytes('^(opt )?%s(?:[%s]+(.*))?$' % (keyword, WHITESPACE)), descriptor.get_bytes(), re.MULTILINE) result = None @@ -653,8 +663,8 @@ def _parse(descriptor, entries): return _parse -def _parse_int_line(keyword, attribute, allow_negative = True): - def _parse(descriptor, entries): +def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) try: @@ -670,10 +680,10 @@ def _parse(descriptor, entries): return _parse -def _parse_timestamp_line(keyword, attribute): +def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: # "" YYYY-MM-DD HH:MM:SS - def _parse(descriptor, entries): + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) try: @@ -684,10 +694,10 @@ def _parse(descriptor, entries): return _parse -def _parse_forty_character_hex(keyword, attribute): +def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: # format of fingerprints, sha1 digests, etc - def _parse(descriptor, entries): + def _parse(descriptor: 
'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) if not stem.util.tor_tools.is_hex_digits(value, 40): @@ -698,15 +708,15 @@ def _parse(descriptor, entries): return _parse -def _parse_protocol_line(keyword, attribute): - def _parse(descriptor, entries): +def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # parses 'protocol' entries like: Cons=1-2 Desc=1-2 DirCache=1 HSDir=1 value = _value(keyword, entries) protocols = collections.OrderedDict() for k, v in _mappings_for(keyword, value): - versions = [] + versions = [] # type: List[int] if not v: continue @@ -729,8 +739,8 @@ def _parse(descriptor, entries): return _parse -def _parse_key_block(keyword, attribute, expected_block_type, value_attribute = None): - def _parse(descriptor, entries): +def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, block_type, block_contents = entries[keyword][0] if not block_contents or block_type != expected_block_type: @@ -744,7 +754,7 @@ def _parse(descriptor, entries): return _parse -def _mappings_for(keyword, value, require_value = False, divider = ' '): +def _mappings_for(keyword: str, value: str, require_value: bool = False, divider: str = ' ') -> Iterator[Tuple[str, str]]: """ Parses an attribute as a series of 'key=value' mappings. 
Unlike _parse_* functions this is a helper, returning the attribute value rather than setting @@ -777,7 +787,7 @@ def _mappings_for(keyword, value, require_value = False, divider = ' '): yield k, v -def _copy(default): +def _copy(default: Any) -> Any: if default is None or isinstance(default, (bool, stem.exit_policy.ExitPolicy)): return default # immutable elif default in EMPTY_COLLECTION: @@ -786,7 +796,7 @@ def _copy(default): return copy.copy(default) -def _encode_digest(hash_value, encoding): +def _encode_digest(hash_value: 'hashlib._HASH', encoding: 'stem.descriptor.DigestEncoding') -> Union[str, 'hashlib._HASH']: # type: ignore """ Encodes a hash value with the given HashEncoding. """ @@ -808,21 +818,21 @@ class Descriptor(object): Common parent for all types of descriptors. """ - ATTRIBUTES = {} # mapping of 'attribute' => (default_value, parsing_function) - PARSER_FOR_LINE = {} # line keyword to its associated parsing function - TYPE_ANNOTATION_NAME = None + ATTRIBUTES = {} # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] # mapping of 'attribute' => (default_value, parsing_function) + PARSER_FOR_LINE = {} # type: Dict[str, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]] # line keyword to its associated parsing function + TYPE_ANNOTATION_NAME = None # type: Optional[str] - def __init__(self, contents, lazy_load = False): - self._path = None - self._archive_path = None + def __init__(self, contents: bytes, lazy_load: bool = False) -> None: + self._path = None # type: Optional[str] + self._archive_path = None # type: Optional[str] self._raw_contents = contents self._lazy_loading = lazy_load - self._entries = {} - self._hash = None - self._unrecognized_lines = [] + self._entries = {} # type: ENTRY_TYPE + self._hash = None # type: Optional[int] + self._unrecognized_lines = [] # type: List[str] @classmethod - def from_str(cls, content, **kwargs): + def from_str(cls, content: str, **kwargs: Any) -> 
Union['stem.descriptor.Descriptor', List['stem.descriptor.Descriptor']]: """ Provides a :class:`~stem.descriptor.__init__.Descriptor` for the given content. @@ -871,7 +881,7 @@ def from_str(cls, content, **kwargs): raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results)) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls, attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: """ Creates descriptor content with the given attributes. Mandatory fields are filled with dummy information unless data is supplied. This doesn't yet @@ -883,7 +893,7 @@ def content(cls, attr = None, exclude = ()): :param list exclude: mandatory keywords to exclude from the descriptor, this results in an invalid descriptor - :returns: **str** with the content of a descriptor + :returns: **bytes** with the content of a descriptor :raises: * **ImportError** if cryptography is unavailable and sign is True @@ -893,7 +903,7 @@ def content(cls, attr = None, exclude = ()): raise NotImplementedError("The create and content methods haven't been implemented for %s" % cls.__name__) @classmethod - def create(cls, attr = None, exclude = (), validate = True): + def create(cls, attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.Descriptor': """ Creates a descriptor with the given attributes. Mandatory fields are filled with dummy information unless data is supplied. 
This doesn't yet create a @@ -915,9 +925,9 @@ def create(cls, attr = None, exclude = (), validate = True): * **NotImplementedError** if not implemented for this descriptor type """ - return cls(cls.content(attr, exclude), validate = validate) + return cls(cls.content(attr, exclude), validate = validate) # type: ignore - def type_annotation(self): + def type_annotation(self) -> 'stem.descriptor.TypeAnnotation': """ Provides the `Tor metrics annotation `_ of this @@ -939,7 +949,7 @@ def type_annotation(self): else: raise NotImplementedError('%s does not have a @type annotation' % type(self).__name__) - def get_path(self): + def get_path(self) -> str: """ Provides the absolute path that we loaded this descriptor from. @@ -948,7 +958,7 @@ def get_path(self): return self._path - def get_archive_path(self): + def get_archive_path(self) -> str: """ If this descriptor came from an archive then provides its path within the archive. This is only set if the descriptor was read by @@ -960,7 +970,7 @@ def get_archive_path(self): return self._archive_path - def get_bytes(self): + def get_bytes(self) -> bytes: """ Provides the ASCII **bytes** of the descriptor. This only differs from **str()** if you're running python 3.x, in which case **str()** provides a @@ -971,7 +981,7 @@ def get_bytes(self): return stem.util.str_tools._to_bytes(self._raw_contents) - def get_unrecognized_lines(self): + def get_unrecognized_lines(self) -> List[str]: """ Provides a list of lines that were either ignored or had data that we did not know how to process. This is most common due to new descriptor fields @@ -987,7 +997,7 @@ def get_unrecognized_lines(self): return list(self._unrecognized_lines) - def _parse(self, entries, validate, parser_for_line = None): + def _parse(self, entries: ENTRY_TYPE, validate: bool, parser_for_line: Optional[Dict[str, Callable]] = None) -> None: """ Parses a series of 'keyword => (value, pgp block)' mappings and applies them as attributes. 
@@ -1018,16 +1028,16 @@ def _parse(self, entries, validate, parser_for_line = None): if validate: raise - def _set_path(self, path): + def _set_path(self, path: str) -> None: self._path = path - def _set_archive_path(self, path): + def _set_archive_path(self, path: str) -> None: self._archive_path = path - def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return str(type(self)) - def _digest_for_signature(self, signing_key, signature): + def _digest_for_signature(self, signing_key: str, signature: str) -> str: """ Provides the signed digest we should have given this key and signature. @@ -1089,13 +1099,15 @@ def _digest_for_signature(self, signing_key, signature): digest_hex = codecs.encode(decrypted_bytes[seperator_index + 1:], 'hex_codec') return stem.util.str_tools._to_unicode(digest_hex.upper()) - def _content_range(self, start = None, end = None): + def _content_range(self, start: Optional[Union[str, bytes]] = None, end: Optional[Union[str, bytes]] = None) -> bytes: """ Provides the descriptor content inclusively between two substrings. 
:param bytes start: start of the content range to get :param bytes end: end of the content range to get + :returns: **bytes** within the given range + :raises: ValueError if either the start or end substring are not within our content """ @@ -1106,24 +1118,24 @@ def _content_range(self, start = None, end = None): start_index = content.find(stem.util.str_tools._to_bytes(start)) if start_index == -1: - raise ValueError("'%s' is not present within our descriptor content" % start) + raise ValueError("'%s' is not present within our descriptor content" % stem.util.str_tools._to_unicode(start)) if end is not None: end_index = content.find(stem.util.str_tools._to_bytes(end), start_index) if end_index == -1: - raise ValueError("'%s' is not present within our descriptor content" % end) + raise ValueError("'%s' is not present within our descriptor content" % stem.util.str_tools._to_unicode(end)) end_index += len(end) # make the ending index inclusive return content[start_index:end_index] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # We can't use standard hasattr() since it calls this function, recursing. # Doing so works since it stops recursing after several dozen iterations # (not sure why), but horrible in terms of performance. 
- def has_attr(attr): + def has_attr(attr: str) -> bool: try: super(Descriptor, self).__getattribute__(attr) return True @@ -1154,31 +1166,31 @@ def has_attr(attr): return super(Descriptor, self).__getattribute__(name) - def __str__(self): + def __str__(self) -> str: return stem.util.str_tools._to_unicode(self._raw_contents) - def _compare(self, other, method): + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool: if type(self) != type(other): return False return method(str(self).strip(), str(other).strip()) - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(str(self).strip()) return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s == o) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s < o) - def __le__(self, other): + def __le__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s <= o) @@ -1187,27 +1199,31 @@ class NewlineNormalizer(object): File wrapper that normalizes CRLF line endings. 
""" - def __init__(self, wrapped_file): + def __init__(self, wrapped_file: BinaryIO) -> None: self._wrapped_file = wrapped_file self.name = getattr(wrapped_file, 'name', None) - def read(self, *args): + def read(self, *args: Any) -> bytes: return self._wrapped_file.read(*args).replace(b'\r\n', b'\n') - def readline(self, *args): + def readline(self, *args: Any) -> bytes: return self._wrapped_file.readline(*args).replace(b'\r\n', b'\n') - def readlines(self, *args): + def readlines(self, *args: Any) -> List[bytes]: return [line.rstrip(b'\r') for line in self._wrapped_file.readlines(*args)] - def seek(self, *args): + def seek(self, *args: Any) -> int: return self._wrapped_file.seek(*args) - def tell(self, *args): + def tell(self, *args: Any) -> int: return self._wrapped_file.tell(*args) -def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_first = False, skip = False, end_position = None, include_ending_keyword = False): +def _read_until_keywords(keywords: Union[str, Sequence[str]], descriptor_file: BinaryIO, inclusive: bool = False, ignore_first: bool = False, skip: bool = False, end_position: Optional[int] = None) -> List[bytes]: + return _read_until_keywords_with_ending_keyword(keywords, descriptor_file, inclusive, ignore_first, skip, end_position, include_ending_keyword = False) # type: ignore + + +def _read_until_keywords_with_ending_keyword(keywords: Union[str, Sequence[str]], descriptor_file: BinaryIO, inclusive: bool = False, ignore_first: bool = False, skip: bool = False, end_position: Optional[int] = None, include_ending_keyword: bool = False) -> Tuple[List[bytes], str]: """ Reads from the descriptor file until we get to one of the given keywords or reach the end of the file. 
@@ -1226,7 +1242,7 @@ def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_fi **True** """ - content = None if skip else [] + content = None if skip else [] # type: Optional[List[bytes]] ending_keyword = None if isinstance(keywords, (bytes, str)): @@ -1268,10 +1284,10 @@ def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_fi if include_ending_keyword: return (content, ending_keyword) else: - return content + return content # type: ignore -def _bytes_for_block(content): +def _bytes_for_block(content: str) -> bytes: """ Provides the base64 decoded content of a pgp-style block. @@ -1289,7 +1305,7 @@ def _bytes_for_block(content): return base64.b64decode(stem.util.str_tools._to_bytes(content)) -def _get_pseudo_pgp_block(remaining_contents): +def _get_pseudo_pgp_block(remaining_contents: List[str]) -> Tuple[str, str]: """ Checks if given contents begins with a pseudo-Open-PGP-style block and, if so, pops it off and provides it back to the caller. @@ -1309,7 +1325,7 @@ def _get_pseudo_pgp_block(remaining_contents): if block_match: block_type = block_match.groups()[0] - block_lines = [] + block_lines = [] # type: List[str] end_line = PGP_BLOCK_END % block_type while True: @@ -1325,7 +1341,7 @@ def _get_pseudo_pgp_block(remaining_contents): return None -def create_signing_key(private_key = None): +def create_signing_key(private_key: Optional['cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey'] = None) -> 'stem.descriptor.SigningKey': # type: ignore """ Serializes a signing key if we have one. Otherwise this creates a new signing key we can use to create descriptors. 
@@ -1361,11 +1377,11 @@ def create_signing_key(private_key = None): # # https://github.com/pyca/cryptography/issues/3713 - def no_op(*args, **kwargs): + def no_op(*args: Any, **kwargs: Any) -> int: return 1 - private_key._backend._lib.EVP_PKEY_CTX_set_signature_md = no_op - private_key._backend.openssl_assert = no_op + private_key._backend._lib.EVP_PKEY_CTX_set_signature_md = no_op # type: ignore + private_key._backend.openssl_assert = no_op # type: ignore public_key = private_key.public_key() public_digest = b'\n' + public_key.public_bytes( @@ -1376,7 +1392,7 @@ def no_op(*args, **kwargs): return SigningKey(private_key, public_key, public_digest) -def _append_router_signature(content, private_key): +def _append_router_signature(content: bytes, private_key: 'cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey') -> bytes: # type: ignore """ Appends a router signature to a server or extrainfo descriptor. @@ -1397,23 +1413,23 @@ def _append_router_signature(content, private_key): return content + b'\n'.join([b'-----BEGIN SIGNATURE-----'] + stem.util.str_tools._split_by_length(signature, 64) + [b'-----END SIGNATURE-----\n']) -def _random_nickname(): +def _random_nickname() -> str: return ('Unnamed%i' % random.randint(0, 100000000000000))[:19] -def _random_fingerprint(): +def _random_fingerprint() -> str: return ('%040x' % random.randrange(16 ** 40)).upper() -def _random_ipv4_address(): +def _random_ipv4_address() -> str: return '%i.%i.%i.%i' % (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) -def _random_date(): +def _random_date() -> str: return '%i-%02i-%02i %02i:%02i:%02i' % (random.randint(2000, 2015), random.randint(1, 12), random.randint(1, 20), random.randint(0, 23), random.randint(0, 59), random.randint(0, 59)) -def _random_crypto_blob(block_type = None): +def _random_crypto_blob(block_type: Optional[str] = None) -> str: """ Provides a random string that can be used for crypto blocks. 
""" @@ -1427,7 +1443,11 @@ def _random_crypto_blob(block_type = None): return crypto_blob -def _descriptor_components(raw_contents, validate, extra_keywords = (), non_ascii_fields = ()): +def _descriptor_components(raw_contents: bytes, validate: bool, non_ascii_fields: Sequence[str] = ()) -> ENTRY_TYPE: + return _descriptor_components_with_extra(raw_contents, validate, (), non_ascii_fields) # type: ignore + + +def _descriptor_components_with_extra(raw_contents: bytes, validate: bool, extra_keywords: Sequence[str] = (), non_ascii_fields: Sequence[str] = ()) -> Tuple[ENTRY_TYPE, List[str]]: """ Initial breakup of the server descriptor contents to make parsing easier. @@ -1441,7 +1461,7 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci entries because this influences the resulting exit policy, but for everything else in server descriptors the order does not matter. - :param str raw_contents: descriptor content provided by the relay + :param bytes raw_contents: descriptor content provided by the relay :param bool validate: checks the validity of the descriptor's content if True, skips these checks otherwise :param list extra_keywords: entity keywords to put into a separate listing @@ -1454,12 +1474,9 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci value tuple, the second being a list of those entries. 
""" - if isinstance(raw_contents, bytes): - raw_contents = stem.util.str_tools._to_unicode(raw_contents) - - entries = collections.OrderedDict() + entries = collections.OrderedDict() # type: ENTRY_TYPE extra_entries = [] # entries with a keyword in extra_keywords - remaining_lines = raw_contents.split('\n') + remaining_lines = stem.util.str_tools._to_unicode(raw_contents).split('\n') while remaining_lines: line = remaining_lines.pop(0) @@ -1523,7 +1540,7 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci if extra_keywords: return entries, extra_entries else: - return entries + return entries # type: ignore # importing at the end to avoid circular dependencies on our Descriptor class diff --git a/stem/descriptor/bandwidth_file.py b/stem/descriptor/bandwidth_file.py index 3cf205958..f1f0b1e2e 100644 --- a/stem/descriptor/bandwidth_file.py +++ b/stem/descriptor/bandwidth_file.py @@ -21,7 +21,10 @@ import stem.util.str_tools +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type + from stem.descriptor import ( + ENTRY_TYPE, _mappings_for, Descriptor, ) @@ -50,7 +53,7 @@ class RecentStats(object): :var RelayFailures relay_failures: number of relays we failed to measure """ - def __init__(self): + def __init__(self) -> None: self.consensus_count = None self.prioritized_relays = None self.prioritized_relay_lists = None @@ -73,7 +76,7 @@ class RelayFailures(object): by default) """ - def __init__(self): + def __init__(self) -> None: self.no_measurement = None self.insuffient_period = None self.insufficient_measurements = None @@ -83,22 +86,22 @@ def __init__(self): # Converts header attributes to a given type. Malformed fields should be # ignored according to the spec. 
-def _str(val): +def _str(val: str) -> str: return val # already a str -def _int(val): +def _int(val: str) -> int: return int(val) if (val and val.isdigit()) else None -def _date(val): +def _date(val: str) -> datetime.datetime: try: return stem.util.str_tools._parse_iso_timestamp(val) except ValueError: return None # not an iso formatted date -def _csv(val): +def _csv(val: str) -> Sequence[str]: return list(map(lambda v: v.strip(), val.split(','))) if val is not None else None @@ -150,7 +153,7 @@ def _csv(val): } -def _parse_file(descriptor_file, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']: """ Iterates over the bandwidth authority metrics in a file. @@ -166,11 +169,14 @@ def _parse_file(descriptor_file, validate = False, **kwargs): * **IOError** if the file can't be read """ - yield BandwidthFile(descriptor_file.read(), validate, **kwargs) + if kwargs: + raise ValueError('BUG: keyword arguments unused by bandwidth files') + + yield BandwidthFile(descriptor_file.read(), validate) -def _parse_header(descriptor, entries): - header = collections.OrderedDict() +def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: + header = collections.OrderedDict() # type: collections.OrderedDict[str, str] content = io.BytesIO(descriptor.get_bytes()) content.readline() # skip the first line, which should be the timestamp @@ -195,7 +201,7 @@ def _parse_header(descriptor, entries): if key == 'version': version_index = index else: - raise ValueError("Header expected to be key=value pairs, but had '%s'" % line) + raise ValueError("Header expected to be key=value pairs, but had '%s'" % stem.util.str_tools._to_unicode(line)) index += 1 @@ -214,16 +220,16 @@ def _parse_header(descriptor, entries): raise ValueError("The 'version' header must be in the second position") -def _parse_timestamp(descriptor, entries): +def 
_parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: first_line = io.BytesIO(descriptor.get_bytes()).readline().strip() if first_line.isdigit(): descriptor.timestamp = datetime.datetime.utcfromtimestamp(int(first_line)) else: - raise ValueError("First line should be a unix timestamp, but was '%s'" % first_line) + raise ValueError("First line should be a unix timestamp, but was '%s'" % stem.util.str_tools._to_unicode(first_line)) -def _parse_body(descriptor, entries): +def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # In version 1.0.0 the body is everything after the first line. Otherwise # it's everything after the header's divider. @@ -237,13 +243,13 @@ def _parse_body(descriptor, entries): measurements = {} - for line in content.readlines(): - line = stem.util.str_tools._to_unicode(line.strip()) + for line_bytes in content.readlines(): + line = stem.util.str_tools._to_unicode(line_bytes.strip()) attr = dict(_mappings_for('measurement', line)) fingerprint = attr.get('node_id', '').lstrip('$') # bwauths prefix fingerprints with '$' if not fingerprint: - raise ValueError("Every meaurement must include 'node_id': %s" % line) + raise ValueError("Every measurement must include 'node_id': %s" % stem.util.str_tools._to_unicode(line)) elif fingerprint in measurements: raise ValueError('Relay %s is listed multiple times. It should only be present once.'
% fingerprint) @@ -296,12 +302,12 @@ class BandwidthFile(Descriptor): 'timestamp': (None, _parse_timestamp), 'header': ({}, _parse_header), 'measurements': ({}, _parse_body), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] ATTRIBUTES.update(dict([(k, (None, _parse_header)) for k in HEADER_ATTR.keys()])) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: """ Creates descriptor content with the given attributes. This descriptor type differs somewhat from others and treats our attr/exclude attributes as @@ -326,7 +332,7 @@ def content(cls, attr = None, exclude = ()): header = collections.OrderedDict(attr) if attr is not None else collections.OrderedDict() timestamp = header.pop('timestamp', str(int(time.time()))) - content = header.pop('content', []) + content = header.pop('content', []) # type: List[str] # type: ignore version = header.get('version', HEADER_DEFAULT.get('version')) lines = [] @@ -352,7 +358,7 @@ def content(cls, attr = None, exclude = ()): return b'\n'.join(lines) - def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(BandwidthFile, self).__init__(raw_content, lazy_load = not validate) if validate: diff --git a/stem/descriptor/certificate.py b/stem/descriptor/certificate.py index 0522b883a..bc09be2dd 100644 --- a/stem/descriptor/certificate.py +++ b/stem/descriptor/certificate.py @@ -64,6 +64,8 @@ import stem.util.str_tools from stem.client.datatype import CertType, Field, Size, split +from stem.descriptor import ENTRY_TYPE +from typing import Callable, List, Optional, Sequence, Tuple, Union ED25519_KEY_LENGTH = 32 ED25519_HEADER_LENGTH = 40 @@ -88,7 +90,7 @@ class Ed25519Extension(Field): :var bytes data: data the extension concerns """ - def __init__(self, 
ext_type, flag_val, data): + def __init__(self, ext_type: 'stem.descriptor.certificate.ExtensionType', flag_val: int, data: bytes) -> None: self.type = ext_type self.flags = [] self.flag_int = flag_val if flag_val else 0 @@ -104,7 +106,7 @@ def __init__(self, ext_type, flag_val, data): if ext_type == ExtensionType.HAS_SIGNING_KEY and len(data) != 32: raise ValueError('Ed25519 HAS_SIGNING_KEY extension must be 32 bytes, but was %i.' % len(data)) - def pack(self): + def pack(self) -> bytes: encoded = bytearray() encoded += Size.SHORT.pack(len(self.data)) encoded += Size.CHAR.pack(self.type) @@ -113,7 +115,7 @@ def pack(self): return bytes(encoded) @staticmethod - def pop(content): + def pop(content: bytes) -> Tuple['stem.descriptor.certificate.Ed25519Extension', bytes]: if len(content) < 4: raise ValueError('Ed25519 extension is missing header fields') @@ -127,7 +129,7 @@ def pop(content): return Ed25519Extension(ext_type, flags, data), content - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'type', 'flag_int', 'data', cache = True) @@ -138,11 +140,11 @@ class Ed25519Certificate(object): :var int version: certificate format version """ - def __init__(self, version): + def __init__(self, version: int) -> None: self.version = version @staticmethod - def unpack(content): + def unpack(content: bytes) -> 'stem.descriptor.certificate.Ed25519Certificate': """ Parses a byte encoded ED25519 certificate. @@ -162,7 +164,7 @@ def unpack(content): raise ValueError('Ed25519 certificate is version %i. Parser presently only supports version 1.' % version) @staticmethod - def from_base64(content): + def from_base64(content: str) -> 'stem.descriptor.certificate.Ed25519Certificate': """ Parses a base64 encoded ED25519 certificate. 
@@ -189,7 +191,7 @@ def from_base64(content): except (TypeError, binascii.Error) as exc: raise ValueError("Ed25519 certificate wasn't propoerly base64 encoded (%s):\n%s" % (exc, content)) - def pack(self): + def pack(self) -> bytes: """ Encoded byte representation of our certificate. @@ -198,7 +200,7 @@ def pack(self): raise NotImplementedError('Certificate encoding has not been implemented for %s' % type(self).__name__) - def to_base64(self, pem = False): + def to_base64(self, pem: bool = False) -> str: """ Base64 encoded certificate data. @@ -206,7 +208,7 @@ def to_base64(self, pem = False): `_, for more information see `RFC 7468 `_ - :returns: **unicode** for our encoded certificate representation + :returns: **str** for our encoded certificate representation """ encoded = b'\n'.join(stem.util.str_tools._split_by_length(base64.b64encode(self.pack()), 64)) @@ -217,7 +219,7 @@ def to_base64(self, pem = False): return stem.util.str_tools._to_unicode(encoded) @staticmethod - def _from_descriptor(keyword, attribute): + def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: def _parse(descriptor, entries): value, block_type, block_contents = entries[keyword][0] @@ -228,7 +230,7 @@ def _parse(descriptor, entries): return _parse - def __str__(self): + def __str__(self) -> str: return self.to_base64(pem = True) @@ -252,7 +254,7 @@ class Ed25519CertificateV1(Ed25519Certificate): is unavailable """ - def __init__(self, cert_type = None, expiration = None, key_type = None, key = None, extensions = None, signature = None, signing_key = None): + def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: 
Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None: # type: ignore super(Ed25519CertificateV1, self).__init__(1) if cert_type is None: @@ -260,12 +262,15 @@ def __init__(self, cert_type = None, expiration = None, key_type = None, key = N elif key is None: raise ValueError('Certificate key is required') + self.type = None # type: Optional[stem.client.datatype.CertType] + self.type_int = None # type: Optional[int] + self.type, self.type_int = CertType.get(cert_type) - self.expiration = expiration if expiration else datetime.datetime.utcnow() + datetime.timedelta(hours = DEFAULT_EXPIRATION_HOURS) - self.key_type = key_type if key_type else 1 - self.key = stem.util._pubkey_bytes(key) - self.extensions = extensions if extensions else [] - self.signature = signature + self.expiration = expiration if expiration else datetime.datetime.utcnow() + datetime.timedelta(hours = DEFAULT_EXPIRATION_HOURS) # type: datetime.datetime + self.key_type = key_type if key_type else 1 # type: int + self.key = stem.util._pubkey_bytes(key) # type: bytes + self.extensions = list(extensions) if extensions else [] # type: List[stem.descriptor.certificate.Ed25519Extension] + self.signature = signature # type: Optional[bytes] if signing_key: calculated_sig = signing_key.sign(self.pack()) @@ -284,7 +289,7 @@ def __init__(self, cert_type = None, expiration = None, key_type = None, key = N elif self.type == CertType.UNKNOWN: raise ValueError('Ed25519 certificate type %i is unrecognized' % self.type_int) - def pack(self): + def pack(self) -> bytes: encoded = bytearray() encoded += Size.CHAR.pack(self.version) encoded += Size.CHAR.pack(self.type_int) @@ -302,7 +307,7 @@ def pack(self): return bytes(encoded) @staticmethod - def unpack(content): + def unpack(content: bytes) -> 'stem.descriptor.certificate.Ed25519CertificateV1': if len(content) < ED25519_HEADER_LENGTH + ED25519_SIGNATURE_LENGTH: raise ValueError('Ed25519 certificate was %i bytes, but should be 
at least %i' % (len(content), ED25519_HEADER_LENGTH + ED25519_SIGNATURE_LENGTH)) @@ -329,7 +334,7 @@ def unpack(content): return Ed25519CertificateV1(cert_type, datetime.datetime.utcfromtimestamp(expiration_hours * 3600), key_type, key, extensions, signature) - def is_expired(self): + def is_expired(self) -> bool: """ Checks if this certificate is presently expired or not. @@ -338,7 +343,7 @@ def is_expired(self): return datetime.datetime.now() > self.expiration - def signing_key(self): + def signing_key(self) -> bytes: """ Provides this certificate's signing key. @@ -354,7 +359,7 @@ def signing_key(self): return None - def validate(self, descriptor): + def validate(self, descriptor: Union['stem.descriptor.server_descriptor.RelayDescriptor', 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3']) -> None: """ Validate our descriptor content matches its ed25519 signature. Supported descriptor types include... @@ -410,7 +415,7 @@ def validate(self, descriptor): raise ValueError('Descriptor Ed25519 certificate signature invalid (signature forged or corrupt)') @staticmethod - def _signed_content(descriptor): + def _signed_content(descriptor: Union['stem.descriptor.server_descriptor.RelayDescriptor', 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3']) -> bytes: """ Provides this descriptor's signing constant, appended with the portion of the descriptor that's signed. 
diff --git a/stem/descriptor/collector.py b/stem/descriptor/collector.py index 7aeb298bf..9749dadb0 100644 --- a/stem/descriptor/collector.py +++ b/stem/descriptor/collector.py @@ -63,6 +63,7 @@ import stem.util.str_tools from stem.descriptor import Compression, DocumentHandler +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union COLLECTOR_URL = 'https://collector.torproject.org/' REFRESH_INDEX_RATE = 3600 # get new index if cached copy is an hour old @@ -76,7 +77,7 @@ FUTURE = datetime.datetime(9999, 1, 1) -def get_instance(): +def get_instance() -> 'stem.descriptor.collector.CollecTor': """ Provides the singleton :class:`~stem.descriptor.collector.CollecTor` used for this module's shorthand functions. @@ -92,7 +93,7 @@ def get_instance(): return SINGLETON_COLLECTOR -def get_server_descriptors(start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): +def get_server_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_server_descriptors` @@ -103,7 +104,7 @@ def get_server_descriptors(start = None, end = None, cache_to = None, bridge = F yield desc -def get_extrainfo_descriptors(start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): +def get_extrainfo_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_extrainfo_descriptors` @@ -114,7 +115,7 @@ def get_extrainfo_descriptors(start = None, end = None, cache_to = 
None, bridge yield desc -def get_microdescriptors(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_microdescriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_microdescriptors` @@ -125,7 +126,7 @@ def get_microdescriptors(start = None, end = None, cache_to = None, timeout = No yield desc -def get_consensus(start = None, end = None, cache_to = None, document_handler = DocumentHandler.ENTRIES, version = 3, microdescriptor = False, bridge = False, timeout = None, retries = 3): +def get_consensus(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntry]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_consensus` @@ -136,7 +137,7 @@ def get_consensus(start = None, end = None, cache_to = None, document_handler = yield desc -def get_key_certificates(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_key_certificates(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.networkstatus.KeyCertificate]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_key_certificates` @@ -147,7 +148,7 @@ def get_key_certificates(start = None, end = None, cache_to = None, timeout = No yield desc -def get_bandwidth_files(start = None, end = None, cache_to = None, 
timeout = None, retries = 3): +def get_bandwidth_files(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.bandwidth_file.BandwidthFile]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_bandwidth_files` @@ -158,7 +159,7 @@ def get_bandwidth_files(start = None, end = None, cache_to = None, timeout = Non yield desc -def get_exit_lists(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_exit_lists(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.tordnsel.TorDNSEL]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_exit_lists` @@ -187,14 +188,14 @@ class File(object): :var datetime last_modified: when the file was last modified """ - def __init__(self, path, types, size, sha256, first_published, last_published, last_modified): + def __init__(self, path: str, types: Tuple[str, ...], size: int, sha256: str, first_published: str, last_published: str, last_modified: str) -> None: self.path = path self.types = tuple(types) if types else () self.compression = File._guess_compression(path) self.size = size self.sha256 = sha256 self.last_modified = datetime.datetime.strptime(last_modified, '%Y-%m-%d %H:%M') - self._downloaded_to = None # location we last downloaded to + self._downloaded_to = None # type: Optional[str] # location we last downloaded to # Most descriptor types have publication time fields, but microdescriptors # don't because these files lack timestamps to parse. 
@@ -205,7 +206,7 @@ def __init__(self, path, types, size, sha256, first_published, last_published, l else: self.start, self.end = File._guess_time_range(path) - def read(self, directory = None, descriptor_type = None, start = None, end = None, document_handler = DocumentHandler.ENTRIES, timeout = None, retries = 3): + def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.Descriptor]: """ Provides descriptors from this archive. Descriptors are downloaded or read from disk as follows... @@ -289,7 +290,7 @@ def read(self, directory = None, descriptor_type = None, start = None, end = Non yield desc - def download(self, directory, decompress = True, timeout = None, retries = 3, overwrite = False): + def download(self, directory: str, decompress: bool = True, timeout: Optional[int] = None, retries: Optional[int] = 3, overwrite: bool = False) -> str: """ Downloads this file to the given location. If a file already exists this is a no-op. @@ -324,8 +325,8 @@ def download(self, directory, decompress = True, timeout = None, retries = 3, ov # check if this file already exists with the correct checksum if os.path.exists(path): - with open(path) as prior_file: - expected_hash = binascii.hexlify(base64.b64decode(self.sha256)) + with open(path, 'rb') as prior_file: + expected_hash = binascii.hexlify(base64.b64decode(self.sha256)).decode('utf-8') actual_hash = hashlib.sha256(prior_file.read()).hexdigest() if expected_hash == actual_hash: @@ -345,7 +346,7 @@ def download(self, directory, decompress = True, timeout = None, retries = 3, ov return path @staticmethod - def _guess_compression(path): + def _guess_compression(path: str) -> stem.descriptor._Compression: """ Determine file comprssion from CollecTor's filename. 
""" @@ -357,7 +358,7 @@ def _guess_compression(path): return Compression.PLAINTEXT @staticmethod - def _guess_time_range(path): + def _guess_time_range(path: str) -> Tuple[datetime.datetime, datetime.datetime]: """ Attemt to determine the (start, end) time range from CollecTor's filename. This provides (None, None) if this cannot be determined. @@ -398,15 +399,15 @@ class CollecTor(object): :var float timeout: duration before we'll time out our request """ - def __init__(self, retries = 2, timeout = None): + def __init__(self, retries: Optional[int] = 2, timeout: Optional[int] = None) -> None: self.retries = retries self.timeout = timeout self._cached_index = None - self._cached_files = None - self._cached_index_at = 0 + self._cached_files = None # type: Optional[List[File]] + self._cached_index_at = 0.0 - def get_server_descriptors(self, start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): + def get_server_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ Provides server descriptors published during the given time range, sorted oldest to newest. 
@@ -431,9 +432,9 @@ def get_server_descriptors(self, start = None, end = None, cache_to = None, brid for f in self.files(desc_type, start, end): for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def get_extrainfo_descriptors(self, start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): + def get_extrainfo_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor]: """ Provides extrainfo descriptors published during the given time range, sorted oldest to newest. @@ -458,9 +459,9 @@ def get_extrainfo_descriptors(self, start = None, end = None, cache_to = None, b for f in self.files(desc_type, start, end): for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def get_microdescriptors(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_microdescriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: """ Provides microdescriptors estimated to be published during the given time range, sorted oldest to newest. 
Unlike server/extrainfo descriptors, @@ -492,9 +493,9 @@ def get_microdescriptors(self, start = None, end = None, cache_to = None, timeou for f in self.files('microdescriptor', start, end): for desc in f.read(cache_to, 'microdescriptor', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def get_consensus(self, start = None, end = None, cache_to = None, document_handler = DocumentHandler.ENTRIES, version = 3, microdescriptor = False, bridge = False, timeout = None, retries = 3): + def get_consensus(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntry]: """ Provides consensus router status entries published during the given time range, sorted oldest to newest. @@ -536,9 +537,9 @@ def get_consensus(self, start = None, end = None, cache_to = None, document_hand for f in self.files(desc_type, start, end): for desc in f.read(cache_to, desc_type, start, end, document_handler, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def get_key_certificates(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_key_certificates(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.networkstatus.KeyCertificate]: """ Directory authority key certificates for the given time range, sorted oldest to newest. 
@@ -560,9 +561,9 @@ def get_key_certificates(self, start = None, end = None, cache_to = None, timeou for f in self.files('dir-key-certificate-3', start, end): for desc in f.read(cache_to, 'dir-key-certificate-3', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def get_bandwidth_files(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_bandwidth_files(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.bandwidth_file.BandwidthFile]: """ Bandwidth authority heuristics for the given time range, sorted oldest to newest. @@ -584,9 +585,9 @@ def get_bandwidth_files(self, start = None, end = None, cache_to = None, timeout for f in self.files('bandwidth-file', start, end): for desc in f.read(cache_to, 'bandwidth-file', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def get_exit_lists(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_exit_lists(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.tordnsel.TorDNSEL]: """ `TorDNSEL exit lists `_ for the given time range, sorted oldest to newest. @@ -608,9 +609,9 @@ def get_exit_lists(self, start = None, end = None, cache_to = None, timeout = No for f in self.files('tordnsel', start, end): for desc in f.read(cache_to, 'tordnsel', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore - def index(self, compression = 'best'): + def index(self, compression: Union[str, stem.descriptor._Compression] = 'best') -> Dict[str, Any]: """ Provides the archives available in CollecTor. 
@@ -631,21 +632,25 @@ def index(self, compression = 'best'): if compression == 'best': for option in (Compression.LZMA, Compression.BZ2, Compression.GZIP, Compression.PLAINTEXT): if option.available: - compression = option + compression_enum = option break elif compression is None: - compression = Compression.PLAINTEXT + compression_enum = Compression.PLAINTEXT + elif isinstance(compression, stem.descriptor._Compression): + compression_enum = compression + else: + raise ValueError('compression must be a descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__)) - extension = compression.extension if compression != Compression.PLAINTEXT else '' + extension = compression_enum.extension if compression_enum != Compression.PLAINTEXT else '' url = COLLECTOR_URL + 'index/index.json' + extension - response = compression.decompress(stem.util.connection.download(url, self.timeout, self.retries)) + response = compression_enum.decompress(stem.util.connection.download(url, self.timeout, self.retries)) self._cached_index = json.loads(stem.util.str_tools._to_unicode(response)) self._cached_index_at = time.time() return self._cached_index - def files(self, descriptor_type = None, start = None, end = None): + def files(self, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None) -> List['stem.descriptor.collector.File']: """ Provides files CollecTor presently has, sorted oldest to newest. @@ -680,7 +685,7 @@ def files(self, descriptor_type = None, start = None, end = None): return matches @staticmethod - def _files(val, path): + def _files(val: Dict[str, Any], path: List[str]) -> List['stem.descriptor.collector.File']: """ Recursively provies files within the index. 
@@ -697,7 +702,7 @@ def _files(val, path): for k, v in val.items(): if k == 'files': - for attr in v: + for attr in v: # Dict[str, str] file_path = '/'.join(path + [attr.get('path')]) files.append(File(file_path, attr.get('types'), attr.get('size'), attr.get('sha256'), attr.get('first_published'), attr.get('last_published'), attr.get('last_modified'))) elif k == 'directories': diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py index d92bb770e..cd9467d16 100644 --- a/stem/descriptor/extrainfo_descriptor.py +++ b/stem/descriptor/extrainfo_descriptor.py @@ -67,6 +67,7 @@ ===================== =========== """ +import datetime import functools import hashlib import re @@ -75,7 +76,10 @@ import stem.util.enum import stem.util.str_tools +from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union + from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, DigestHash, @@ -163,7 +167,7 @@ _locale_re = re.compile('^[a-zA-Z0-9\\?]{2}$') -def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, is_bridge = False, validate = False, **kwargs: Any) -> Iterator['stem.descriptor.extrainfo_descriptor.ExtraInfoDescriptor']: """ Iterates over the extra-info descriptors in a file. 
@@ -181,6 +185,9 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): * **IOError** if the file can't be read """ + if kwargs: + raise ValueError('BUG: keyword arguments unused by extrainfo descriptors') + while True: if not is_bridge: extrainfo_content = _read_until_keywords('router-signature', descriptor_file) @@ -197,14 +204,14 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): extrainfo_content = extrainfo_content[1:] if is_bridge: - yield BridgeExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate, **kwargs) + yield BridgeExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate) else: - yield RelayExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate, **kwargs) + yield RelayExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate) else: break # done parsing file -def _parse_timestamp_and_interval(keyword, content): +def _parse_timestamp_and_interval(keyword: str, content: str) -> Tuple[datetime.datetime, int, str]: """ Parses a 'YYYY-MM-DD HH:MM:SS (NSEC s) *' entry. @@ -238,7 +245,7 @@ def _parse_timestamp_and_interval(keyword, content): raise ValueError("%s line's timestamp wasn't parsable: %s" % (keyword, line)) -def _parse_extra_info_line(descriptor, entries): +def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "extra-info" Nickname Fingerprint value = _value('extra-info', entries) @@ -255,7 +262,7 @@ def _parse_extra_info_line(descriptor, entries): descriptor.fingerprint = extra_info_comp[1] -def _parse_transport_line(descriptor, entries): +def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "transport" transportname address:port [arglist] # Everything after the transportname is scrubbed in published bridge # descriptors, so we'll never see it in practice. 
@@ -301,7 +308,7 @@ def _parse_transport_line(descriptor, entries): descriptor.transport = transports -def _parse_padding_counts_line(descriptor, entries): +def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "padding-counts" YYYY-MM-DD HH:MM:SS (NSEC s) key=val key=val... value = _value('padding-counts', entries) @@ -316,7 +323,7 @@ def _parse_padding_counts_line(descriptor, entries): setattr(descriptor, 'padding_counts', counts) -def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr, descriptor, entries): +def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) recognized_counts = {} @@ -340,7 +347,7 @@ def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr setattr(descriptor, unrecognized_counts_attr, unrecognized_counts) -def _parse_dirreq_share_line(keyword, attribute, descriptor, entries): +def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) if not value.endswith('%'): @@ -353,7 +360,7 @@ def _parse_dirreq_share_line(keyword, attribute, descriptor, entries): setattr(descriptor, attribute, float(value[:-1]) / 100) -def _parse_cell_line(keyword, attribute, descriptor, entries): +def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "" num,...,num value = _value(keyword, entries) @@ -375,7 +382,7 @@ def _parse_cell_line(keyword, attribute, descriptor, entries): raise exc -def _parse_timestamp_and_interval_line(keyword, end_attribute, interval_attribute, descriptor, entries): +def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) 
-> None: # "" YYYY-MM-DD HH:MM:SS (NSEC s) timestamp, interval, _ = _parse_timestamp_and_interval(keyword, _value(keyword, entries)) @@ -383,7 +390,7 @@ def _parse_timestamp_and_interval_line(keyword, end_attribute, interval_attribut setattr(descriptor, interval_attribute, interval) -def _parse_conn_bi_direct_line(descriptor, entries): +def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "conn-bi-direct" YYYY-MM-DD HH:MM:SS (NSEC s) BELOW,READ,WRITE,BOTH value = _value('conn-bi-direct', entries) @@ -401,7 +408,7 @@ def _parse_conn_bi_direct_line(descriptor, entries): descriptor.conn_bi_direct_both = int(stats[3]) -def _parse_history_line(keyword, end_attribute, interval_attribute, values_attribute, descriptor, entries): +def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "" YYYY-MM-DD HH:MM:SS (NSEC s) NUM,NUM,NUM,NUM,NUM... value = _value(keyword, entries) @@ -419,7 +426,7 @@ def _parse_history_line(keyword, end_attribute, interval_attribute, values_attri setattr(descriptor, values_attribute, history_values) -def _parse_port_count_line(keyword, attribute, descriptor, entries): +def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "" port=N,port=N,... 
value, port_mappings = _value(keyword, entries), {} @@ -428,13 +435,13 @@ def _parse_port_count_line(keyword, attribute, descriptor, entries): if (port != 'other' and not stem.util.connection.is_valid_port(port)) or not stat.isdigit(): raise ValueError('Entries in %s line should only be PORT=N entries: %s %s' % (keyword, keyword, value)) - port = int(port) if port.isdigit() else port + port = int(port) if port.isdigit() else port # type: ignore # this can be an int or 'other' port_mappings[port] = int(stat) setattr(descriptor, attribute, port_mappings) -def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries): +def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "" CC=N,CC=N,... # # The maxmind geoip (https://www.maxmind.com/app/iso3166) has numeric @@ -454,7 +461,7 @@ def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries): setattr(descriptor, attribute, locale_usage) -def _parse_bridge_ip_versions_line(descriptor, entries): +def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, ip_versions = _value('bridge-ip-versions', entries), {} for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','): @@ -466,7 +473,7 @@ def _parse_bridge_ip_versions_line(descriptor, entries): descriptor.ip_versions = ip_versions -def _parse_bridge_ip_transports_line(descriptor, entries): +def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, ip_transports = _value('bridge-ip-transports', entries), {} for protocol, count in _mappings_for('bridge-ip-transports', value, divider = ','): @@ -478,7 +485,7 @@ def _parse_bridge_ip_transports_line(descriptor, entries): descriptor.ip_transports = ip_transports -def _parse_hs_stats(keyword, stat_attribute, extra_attribute, descriptor, entries): +def _parse_hs_stats(keyword: str, stat_attribute: 
str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "" num key=val key=val... value, stat, extra = _value(keyword, entries), None, {} @@ -765,7 +772,7 @@ class ExtraInfoDescriptor(Descriptor): 'ip_versions': (None, _parse_bridge_ip_versions_line), 'ip_transports': (None, _parse_bridge_ip_transports_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] PARSER_FOR_LINE = { 'extra-info': _parse_extra_info_line, @@ -814,7 +821,7 @@ class ExtraInfoDescriptor(Descriptor): 'bridge-ip-transports': _parse_bridge_ip_transports_line, } - def __init__(self, raw_contents, validate = False): + def __init__(self, raw_contents: bytes, validate: bool = False) -> None: """ Extra-info descriptor constructor. By default this validates the descriptor's content as it's parsed. This validation can be disabled to @@ -851,7 +858,7 @@ def __init__(self, raw_contents, validate = False): else: self._entries = entries - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this descriptor's content. These are referenced by... 
@@ -876,13 +883,13 @@ def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): raise NotImplementedError('Unsupported Operation: this should be implemented by the ExtraInfoDescriptor subclass') - def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS - def _first_keyword(self): + def _first_keyword(self) -> str: return 'extra-info' - def _last_keyword(self): + def _last_keyword(self) -> str: return 'router-signature' @@ -917,7 +924,7 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor): }) @classmethod - def content(cls, attr = None, exclude = (), sign = False, signing_key = None): + def content(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> bytes: base_header = ( ('extra-info', '%s %s' % (_random_nickname(), _random_fingerprint())), ('published', _random_date()), @@ -938,11 +945,11 @@ def content(cls, attr = None, exclude = (), sign = False, signing_key = None): )) @classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, signing_key = None): + def create(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor': return cls(cls.content(attr, exclude, sign, signing_key), validate = validate) @functools.lru_cache() - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore if hash_type == DigestHash.SHA1: # our digest is 
calculated from everything except our signature @@ -986,7 +993,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): }) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.extrainfo_descriptor.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('extra-info', 'ec2bridgereaac65a3 %s' % _random_fingerprint()), ('published', _random_date()), @@ -994,7 +1001,7 @@ def content(cls, attr = None, exclude = ()): ('router-digest', _random_fingerprint()), )) - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX: return self._digest elif hash_type == DigestHash.SHA256 and encoding == DigestEncoding.BASE64: @@ -1002,7 +1009,7 @@ def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): else: raise NotImplementedError('Bridge extrainfo digests are only available as sha1/hex and sha256/base64, not %s/%s' % (hash_type, encoding)) - def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: excluded_fields = [ 'router-signature', ] @@ -1013,5 +1020,5 @@ def _required_fields(self): return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields]) - def _last_keyword(self): + def _last_keyword(self) -> str: return None diff --git a/stem/descriptor/hidden_service.py b/stem/descriptor/hidden_service.py index 75a78d2ef..2eb7d02fb 100644 --- a/stem/descriptor/hidden_service.py +++ b/stem/descriptor/hidden_service.py @@ -51,8 +51,10 @@ from stem.client.datatype import CertType from stem.descriptor.certificate import ExtensionType, Ed25519Extension, Ed25519Certificate, Ed25519CertificateV1 +from 
typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, _descriptor_content, @@ -103,7 +105,7 @@ 'onion_key': None, 'service_key': None, 'intro_authentication': [], -} +} # type: Dict[str, Any] # introduction-point fields that can only appear once @@ -132,7 +134,7 @@ class DecryptionFailure(Exception): """ -class IntroductionPointV2(collections.namedtuple('IntroductionPointV2', INTRODUCTION_POINTS_ATTR.keys())): +class IntroductionPointV2(collections.namedtuple('IntroductionPointV2', INTRODUCTION_POINTS_ATTR.keys())): # type: ignore """ Introduction point for a v2 hidden service. @@ -162,7 +164,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s """ @staticmethod - def parse(content): + def parse(content: bytes) -> 'stem.descriptor.hidden_service.IntroductionPointV3': """ Parses an introduction point from its descriptor content. @@ -174,7 +176,7 @@ def parse(content): """ entry = _descriptor_components(content, False) - link_specifiers = IntroductionPointV3._parse_link_specifiers(_value('introduction-point', entry)) + link_specifiers = IntroductionPointV3._parse_link_specifiers(stem.util.str_tools._to_bytes(_value('introduction-point', entry))) onion_key_line = _value('onion-key', entry) onion_key = onion_key_line[5:] if onion_key_line.startswith('ntor ') else None @@ -200,7 +202,7 @@ def parse(content): return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, legacy_key, legacy_key_cert) @staticmethod - def create_for_address(address, port, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None): + def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: 
Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': # type: ignore """ Simplified constructor for a single address/port link specifier. @@ -222,6 +224,8 @@ def create_for_address(address, port, expiration = None, onion_key = None, enc_k if not stem.util.connection.is_valid_port(port): raise ValueError("'%s' is an invalid port" % port) + link_specifiers = None # type: Optional[List[stem.client.datatype.LinkSpecifier]] + if stem.util.connection.is_valid_ipv4_address(address): link_specifiers = [stem.client.datatype.LinkByIPv4(address, port)] elif stem.util.connection.is_valid_ipv6_address(address): @@ -232,7 +236,7 @@ def create_for_address(address, port, expiration = None, onion_key = None, enc_k return IntroductionPointV3.create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None) @staticmethod - def create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None): + def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': # type: ignore """ Simplified constructor. For more sophisticated use cases you can use this as a template for how introduction points are properly created. @@ -271,7 +275,7 @@ def create_for_link_specifiers(link_specifiers, expiration = None, onion_key = N return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, None, None) - def encode(self): + def encode(self) -> str: """ Descriptor representation of this introduction point. 
@@ -299,7 +303,7 @@ def encode(self): return '\n'.join(lines) - def onion_key(self): + def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore """ Provides our ntor introduction point public key. @@ -312,7 +316,7 @@ def onion_key(self): return IntroductionPointV3._key_as(self.onion_key_raw, x25519 = True) - def auth_key(self): + def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey': # type: ignore """ Provides our authentication certificate's public key. @@ -325,7 +329,7 @@ def auth_key(self): return IntroductionPointV3._key_as(self.auth_key_cert.key, ed25519 = True) - def enc_key(self): + def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore """ Provides our encryption key. @@ -338,7 +342,7 @@ def enc_key(self): return IntroductionPointV3._key_as(self.enc_key_raw, x25519 = True) - def legacy_key(self): + def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore """ Provides our legacy introduction point public key. 
@@ -352,7 +356,7 @@ def legacy_key(self): return IntroductionPointV3._key_as(self.legacy_key_raw, x25519 = True) @staticmethod - def _key_as(value, x25519 = False, ed25519 = False): + def _key_as(value: bytes, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']: # type: ignore if value is None or (not x25519 and not ed25519): return value @@ -375,11 +379,11 @@ def _key_as(value, x25519 = False, ed25519 = False): return Ed25519PublicKey.from_public_bytes(value) @staticmethod - def _parse_link_specifiers(content): + def _parse_link_specifiers(content: bytes) -> List['stem.client.datatype.LinkSpecifier']: try: content = base64.b64decode(content) except Exception as exc: - raise ValueError('Unable to base64 decode introduction point (%s): %s' % (exc, content)) + raise ValueError('Unable to base64 decode introduction point (%s): %s' % (exc, stem.util.str_tools._to_unicode(content))) link_specifiers = [] count, content = stem.client.datatype.Size.CHAR.pop(content) @@ -389,20 +393,20 @@ def _parse_link_specifiers(content): link_specifiers.append(link_specifier) if content: - raise ValueError('Introduction point had excessive data (%s)' % content) + raise ValueError('Introduction point had excessive data (%s)' % stem.util.str_tools._to_unicode(content)) return link_specifiers - def __hash__(self): + def __hash__(self) -> int: if not hasattr(self, '_hash'): self._hash = hash(self.encode()) return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, IntroductionPointV3) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -417,22 +421,22 @@ class AuthorizedClient(object): :var str cookie: base64 encoded authentication cookie """ - def __init__(self, id = None, iv = None, cookie = None): + def __init__(self, 
id: Optional[str] = None, iv: Optional[str] = None, cookie: Optional[str] = None) -> None: self.id = stem.util.str_tools._to_unicode(id if id else base64.b64encode(os.urandom(8)).rstrip(b'=')) self.iv = stem.util.str_tools._to_unicode(iv if iv else base64.b64encode(os.urandom(16)).rstrip(b'=')) self.cookie = stem.util.str_tools._to_unicode(cookie if cookie else base64.b64encode(os.urandom(16)).rstrip(b'=')) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'id', 'iv', 'cookie', cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, AuthorizedClient) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other -def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, desc_type: Optional[Type['stem.descriptor.hidden_service.HiddenServiceDescriptor']] = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']: """ Iterates over the hidden service descriptors in a file. 
@@ -442,7 +446,7 @@ def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs): **True**, skips these checks otherwise :param dict kwargs: additional arguments for the descriptor constructor - :returns: iterator for :class:`~stem.descriptor.hidden_service.HiddenServiceDescriptorV2` + :returns: iterator for :class:`~stem.descriptor.hidden_service.HiddenServiceDescriptor` instances in the file :raises: @@ -467,12 +471,12 @@ def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs): if descriptor_content[0].startswith(b'@type'): descriptor_content = descriptor_content[1:] - yield desc_type(bytes.join(b'', descriptor_content), validate, **kwargs) + yield desc_type(bytes.join(b'', descriptor_content), validate, **kwargs) # type: ignore else: break # done parsing file -def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, blinded_key): +def _decrypt_layer(encrypted_block: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str: if encrypted_block.startswith('-----BEGIN MESSAGE-----\n') and encrypted_block.endswith('\n-----END MESSAGE-----'): encrypted_block = encrypted_block[24:-22] @@ -491,7 +495,7 @@ def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, b cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt) if expected_mac != mac_for(ciphertext): - raise ValueError('Malformed mac (expected %s, but was %s)' % (expected_mac, mac_for(ciphertext))) + raise ValueError('Malformed mac (expected %s, but was %s)' % (stem.util.str_tools._to_unicode(expected_mac), stem.util.str_tools._to_unicode(mac_for(ciphertext)))) decryptor = cipher.decryptor() plaintext = decryptor.update(ciphertext) + decryptor.finalize() @@ -499,7 +503,7 @@ def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, b return stem.util.str_tools._to_unicode(plaintext) -def _encrypt_layer(plaintext, constant, 
revision_counter, subcredential, blinded_key): +def _encrypt_layer(plaintext: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: salt = os.urandom(16) cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt) @@ -510,7 +514,7 @@ def _encrypt_layer(plaintext, constant, revision_counter, subcredential, blinded return b'-----BEGIN MESSAGE-----\n%s\n-----END MESSAGE-----' % b'\n'.join(stem.util.str_tools._split_by_length(encoded, 64)) -def _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt): +def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]: # type: ignore try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -530,7 +534,7 @@ def _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt): return cipher, lambda ciphertext: hashlib.sha3_256(mac_prefix + ciphertext).digest() -def _parse_protocol_versions_line(descriptor, entries): +def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('protocol-versions', entries) try: @@ -545,7 +549,7 @@ def _parse_protocol_versions_line(descriptor, entries): descriptor.protocol_versions = versions -def _parse_introduction_points_line(descriptor, entries): +def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: _, block_type, block_contents = entries['introduction-points'][0] if not block_contents or block_type != 'MESSAGE': @@ -559,7 +563,7 @@ def _parse_introduction_points_line(descriptor, entries): raise ValueError("'introduction-points' isn't base64 encoded content:\n%s" % block_contents) -def _parse_v3_outer_clients(descriptor, entries): +def 
_parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "auth-client" client-id iv encrypted-cookie clients = {} @@ -575,7 +579,7 @@ def _parse_v3_outer_clients(descriptor, entries): descriptor.clients = clients -def _parse_v3_inner_formats(descriptor, entries): +def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, formats = _value('create2-formats', entries), [] for entry in value.split(' '): @@ -587,7 +591,7 @@ def _parse_v3_inner_formats(descriptor, entries): descriptor.formats = formats -def _parse_v3_introduction_points(descriptor, entries): +def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: if hasattr(descriptor, '_unparsed_introduction_points'): introduction_points = [] remaining = descriptor._unparsed_introduction_points @@ -673,7 +677,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): 'introduction_points_encoded': (None, _parse_introduction_points_line), 'introduction_points_content': (None, _parse_introduction_points_line), 'signature': (None, _parse_v2_signature_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] PARSER_FOR_LINE = { 'rendezvous-service-descriptor': _parse_rendezvous_service_descriptor_line, @@ -687,7 +691,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): } @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('rendezvous-service-descriptor', 'y3olqqblqw2gbh6phimfuiroechjjafa'), ('version', '2'), @@ -701,10 +705,10 @@ def content(cls, attr = None, exclude = ()): )) @classmethod - def create(cls, attr = None, exclude = (), validate = True): + def create(cls: 
Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV2': return cls(cls.content(attr, exclude), validate = validate, skip_crypto_validation = True) - def __init__(self, raw_contents, validate = False, skip_crypto_validation = False): + def __init__(self, raw_contents: bytes, validate: bool = False, skip_crypto_validation: bool = False) -> None: super(HiddenServiceDescriptorV2, self).__init__(raw_contents, lazy_load = not validate) entries = _descriptor_components(raw_contents, validate, non_ascii_fields = ('introduction-points')) @@ -736,10 +740,12 @@ def __init__(self, raw_contents, validate = False, skip_crypto_validation = Fals self._entries = entries @functools.lru_cache() - def introduction_points(self, authentication_cookie = None): + def introduction_points(self, authentication_cookie: Optional[bytes] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']: """ Provided this service's introduction points. 
+ :param bytes authentication_cookie: base64 encoded authentication cookie + :returns: **list** of :class:`~stem.descriptor.hidden_service.IntroductionPointV2` :raises: @@ -774,7 +780,7 @@ def introduction_points(self, authentication_cookie = None): return HiddenServiceDescriptorV2._parse_introduction_points(content) @staticmethod - def _decrypt_basic_auth(content, authentication_cookie): + def _decrypt_basic_auth(content: bytes, authentication_cookie: bytes) -> bytes: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -784,7 +790,7 @@ def _decrypt_basic_auth(content, authentication_cookie): try: client_blocks = int(binascii.hexlify(content[1:2]), 16) except ValueError: - raise DecryptionFailure("When using basic auth the content should start with a number of blocks but wasn't a hex digit: %s" % binascii.hexlify(content[1:2])) + raise DecryptionFailure("When using basic auth the content should start with a number of blocks but wasn't a hex digit: %s" % binascii.hexlify(content[1:2]).decode('utf-8')) # parse the client id and encrypted session keys @@ -821,7 +827,7 @@ def _decrypt_basic_auth(content, authentication_cookie): return content # nope, unable to decrypt the content @staticmethod - def _decrypt_stealth_auth(content, authentication_cookie): + def _decrypt_stealth_auth(content: bytes, authentication_cookie: bytes) -> bytes: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -836,7 +842,7 @@ def _decrypt_stealth_auth(content, authentication_cookie): return decryptor.update(encrypted) + decryptor.finalize() @staticmethod - def _parse_introduction_points(content): + def _parse_introduction_points(content: bytes) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']: """ Provides the parsed list of IntroductionPointV2 for the unencrypted content. 
""" @@ -885,7 +891,7 @@ def _parse_introduction_points(content): auth_type, auth_data = auth_value.split(' ')[:2] auth_entries.append((auth_type, auth_data)) - introduction_points.append(IntroductionPointV2(**attr)) + introduction_points.append(IntroductionPointV2(**attr)) # type: ignore return introduction_points @@ -928,7 +934,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): } @classmethod - def content(cls, attr = None, exclude = (), sign = False, inner_layer = None, outer_layer = None, identity_key = None, signing_key = None, signing_cert = None, revision_counter = None, blinding_nonce = None): + def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.certificate.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> bytes: # type: ignore """ Hidden service v3 descriptors consist of three parts: @@ -989,7 +995,12 @@ def content(cls, attr = None, exclude = (), sign = False, inner_layer = None, ou blinded_key = _blinded_pubkey(identity_key, blinding_nonce) if blinding_nonce else b'a' * 32 subcredential = HiddenServiceDescriptorV3._subcredential(identity_key, blinded_key) - custom_sig = attr.pop('signature') if (attr and 'signature' in attr) else None + + if attr and 'signature' in attr: + custom_sig = attr['signature'] + attr = dict(filter(lambda entry: entry[0] != 'signature', attr.items())) + else: + custom_sig = None if not outer_layer: outer_layer = OuterLayer.create( @@ -1011,7 +1022,7 @@ def content(cls, 
attr = None, exclude = (), sign = False, inner_layer = None, ou ('descriptor-lifetime', '180'), ('descriptor-signing-key-cert', '\n' + signing_cert.to_base64(pem = True)), ('revision-counter', str(revision_counter)), - ('superencrypted', b'\n' + outer_layer._encrypt(revision_counter, subcredential, blinded_key)), + ('superencrypted', stem.util.str_tools._to_unicode(b'\n' + outer_layer._encrypt(revision_counter, subcredential, blinded_key))), ), ()) + b'\n' if custom_sig: @@ -1023,13 +1034,13 @@ def content(cls, attr = None, exclude = (), sign = False, inner_layer = None, ou return desc_content @classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, outer_layer = None, identity_key = None, signing_key = None, signing_cert = None, revision_counter = None, blinding_nonce = None): + def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.certificate.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3': # type: ignore return cls(cls.content(attr, exclude, sign, inner_layer, outer_layer, identity_key, signing_key, signing_cert, revision_counter, blinding_nonce), validate = validate) - def __init__(self, raw_contents, validate = False): + def __init__(self, raw_contents: bytes, validate: bool = False) -> None: super(HiddenServiceDescriptorV3, self).__init__(raw_contents, lazy_load = not validate) - 
self._inner_layer = None + self._inner_layer = None # type: Optional[stem.descriptor.hidden_service.InnerLayer] entries = _descriptor_components(raw_contents, validate) if validate: @@ -1054,7 +1065,7 @@ def __init__(self, raw_contents, validate = False): else: self._entries = entries - def decrypt(self, onion_address): + def decrypt(self, onion_address: str) -> 'stem.descriptor.hidden_service.InnerLayer': """ Decrypt this descriptor. Hidden serice descriptors contain two encryption layers (:class:`~stem.descriptor.hidden_service.OuterLayer` and @@ -1086,7 +1097,7 @@ def decrypt(self, onion_address): return self._inner_layer @staticmethod - def address_from_identity_key(key, suffix = True): + def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str: # type: ignore """ Converts a hidden service identity key into its address. This accepts all key formats (private, public, or public bytes). @@ -1094,7 +1105,7 @@ def address_from_identity_key(key, suffix = True): :param Ed25519PublicKey,Ed25519PrivateKey,bytes key: hidden service identity key :param bool suffix: includes the '.onion' suffix if true, excluded otherwise - :returns: **unicode** hidden service address + :returns: **str** hidden service address :raises: **ImportError** if key is a cryptographic type and ed25519 support is unavailable @@ -1109,7 +1120,7 @@ def address_from_identity_key(key, suffix = True): return stem.util.str_tools._to_unicode(onion_address + b'.onion' if suffix else onion_address).lower() @staticmethod - def identity_key_from_address(onion_address): + def identity_key_from_address(onion_address: str) -> bytes: """ Converts a hidden service address into its public identity key. 
@@ -1146,7 +1157,7 @@ def identity_key_from_address(onion_address): return pubkey @staticmethod - def _subcredential(identity_key, blinded_key): + def _subcredential(identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes) -> bytes: # type: ignore # credential = H('credential' | public-identity-key) # subcredential = H('subcredential' | credential | blinded-public-key) @@ -1176,7 +1187,7 @@ class OuterLayer(Descriptor): 'ephemeral_key': (None, _parse_v3_outer_ephemeral_key), 'clients': ({}, _parse_v3_outer_clients), 'encrypted': (None, _parse_v3_outer_encrypted), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] PARSER_FOR_LINE = { 'desc-auth-type': _parse_v3_outer_auth_type, @@ -1186,11 +1197,11 @@ class OuterLayer(Descriptor): } @staticmethod - def _decrypt(encrypted, revision_counter, subcredential, blinded_key): + def _decrypt(encrypted: str, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer': plaintext = _decrypt_layer(encrypted, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key) - return OuterLayer(plaintext) + return OuterLayer(stem.util.str_tools._to_bytes(plaintext)) - def _encrypt(self, revision_counter, subcredential, blinded_key): + def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: # Spec mandated padding: "Before encryption the plaintext is padded with # NUL bytes to the nearest multiple of 10k bytes." 
@@ -1201,7 +1212,7 @@ def _encrypt(self, revision_counter, subcredential, blinded_key): return _encrypt_layer(content, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key) @classmethod - def content(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, revision_counter = None, authorized_clients = None, subcredential = None, blinded_key = None): + def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> bytes: try: from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey @@ -1227,18 +1238,18 @@ def content(cls, attr = None, exclude = (), validate = True, sign = False, inner return _descriptor_content(attr, exclude, [ ('desc-auth-type', 'x25519'), - ('desc-auth-ephemeral-key', base64.b64encode(stem.util._pubkey_bytes(X25519PrivateKey.generate()))), + ('desc-auth-ephemeral-key', stem.util.str_tools._to_unicode(base64.b64encode(stem.util._pubkey_bytes(X25519PrivateKey.generate())))), ] + [ ('auth-client', '%s %s %s' % (c.id, c.iv, c.cookie)) for c in authorized_clients ], ( - ('encrypted', b'\n' + inner_layer._encrypt(revision_counter, subcredential, blinded_key)), + ('encrypted', stem.util.str_tools._to_unicode(b'\n' + inner_layer._encrypt(revision_counter, subcredential, blinded_key))), )) @classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, revision_counter = None, authorized_clients = None, subcredential = None, blinded_key = None): + def create(cls: 
Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: int = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> 'stem.descriptor.hidden_service.OuterLayer': return cls(cls.content(attr, exclude, validate, sign, inner_layer, revision_counter, authorized_clients, subcredential, blinded_key), validate = validate) - def __init__(self, content, validate = False): + def __init__(self, content: bytes, validate: bool = False) -> None: content = stem.util.str_tools._to_bytes(content).rstrip(b'\x00') # strip null byte padding super(OuterLayer, self).__init__(content, lazy_load = not validate) @@ -1282,17 +1293,17 @@ class InnerLayer(Descriptor): } @staticmethod - def _decrypt(outer_layer, revision_counter, subcredential, blinded_key): + def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.InnerLayer': plaintext = _decrypt_layer(outer_layer.encrypted, b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key) - return InnerLayer(plaintext, validate = True, outer_layer = outer_layer) + return InnerLayer(stem.util.str_tools._to_bytes(plaintext), validate = True, outer_layer = outer_layer) - def _encrypt(self, revision_counter, subcredential, blinded_key): + def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: # encrypt back into an outer layer's 'encrypted' field return _encrypt_layer(self.get_bytes(), b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key) @classmethod - def content(cls, attr = None, exclude = (), introduction_points = None): + def content(cls: 
Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> bytes: if introduction_points: suffix = '\n' + '\n'.join(map(IntroductionPointV3.encode, introduction_points)) else: @@ -1303,10 +1314,10 @@ def content(cls, attr = None, exclude = (), introduction_points = None): )) + stem.util.str_tools._to_bytes(suffix) @classmethod - def create(cls, attr = None, exclude = (), validate = True, introduction_points = None): + def create(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> 'stem.descriptor.hidden_service.InnerLayer': return cls(cls.content(attr, exclude, introduction_points), validate = validate) - def __init__(self, content, validate = False, outer_layer = None): + def __init__(self, content: bytes, validate: bool = False, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None) -> None: super(InnerLayer, self).__init__(content, lazy_load = not validate) self.outer = outer_layer @@ -1331,7 +1342,7 @@ def __init__(self, content, validate = False, outer_layer = None): self._entries = entries -def _blinded_pubkey(identity_key, blinding_nonce): +def _blinded_pubkey(identity_key: bytes, blinding_nonce: bytes) -> bytes: from stem.util import ed25519 mult = 2 ** (ed25519.b - 2) + sum(2 ** i * ed25519.bit(blinding_nonce, i) for i in range(3, ed25519.b - 2)) @@ -1339,7 +1350,7 @@ def _blinded_pubkey(identity_key, blinding_nonce): return ed25519.encodepoint(ed25519.scalarmult(P, mult)) -def _blinded_sign(msg, identity_key, blinded_key, blinding_nonce): +def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, 
blinding_nonce: bytes) -> bytes: # type: ignore try: from cryptography.hazmat.primitives import serialization except ImportError: diff --git a/stem/descriptor/microdescriptor.py b/stem/descriptor/microdescriptor.py index c62a3d0d5..7bd241e47 100644 --- a/stem/descriptor/microdescriptor.py +++ b/stem/descriptor/microdescriptor.py @@ -69,7 +69,10 @@ import stem.exit_policy +from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type, Union + from stem.descriptor import ( + ENTRY_TYPE, Descriptor, DigestHash, DigestEncoding, @@ -102,7 +105,7 @@ ) -def _parse_file(descriptor_file, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']: """ Iterates over the microdescriptors in a file. @@ -118,6 +121,9 @@ def _parse_file(descriptor_file, validate = False, **kwargs): * **IOError** if the file can't be read """ + if kwargs: + raise ValueError('BUG: keyword arguments unused by microdescriptors') + while True: annotations = _read_until_keywords('onion-key', descriptor_file) @@ -154,12 +160,12 @@ def _parse_file(descriptor_file, validate = False, **kwargs): descriptor_text = bytes.join(b'', descriptor_lines) - yield Microdescriptor(descriptor_text, validate, annotations, **kwargs) + yield Microdescriptor(descriptor_text, validate, annotations) else: break # done parsing descriptors -def _parse_id_line(descriptor, entries): +def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: identities = {} for entry in _values('id', entries): @@ -244,12 +250,12 @@ class Microdescriptor(Descriptor): } @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('onion-key', _random_crypto_blob('RSA PUBLIC KEY')), 
)) - def __init__(self, raw_contents, validate = False, annotations = None): + def __init__(self, raw_contents: bytes, validate: bool = False, annotations: Optional[Sequence[bytes]] = None) -> None: super(Microdescriptor, self).__init__(raw_contents, lazy_load = not validate) self._annotation_lines = annotations if annotations else [] entries = _descriptor_components(raw_contents, validate) @@ -260,7 +266,7 @@ def __init__(self, raw_contents, validate = False, annotations = None): else: self._entries = entries - def digest(self, hash_type = DigestHash.SHA256, encoding = DigestEncoding.BASE64): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this microdescriptor. These are referenced by... @@ -285,7 +291,7 @@ def digest(self, hash_type = DigestHash.SHA256, encoding = DigestEncoding.BASE64 raise NotImplementedError('Microdescriptor digests are only available in sha1 and sha256, not %s' % hash_type) @functools.lru_cache() - def get_annotations(self): + def get_annotations(self) -> Dict[bytes, bytes]: """ Provides content that appeared prior to the descriptor. If this comes from the cached-microdescs then this commonly contains content like... @@ -308,7 +314,7 @@ def get_annotations(self): return annotation_dict - def get_annotation_lines(self): + def get_annotation_lines(self) -> Sequence[bytes]: """ Provides the lines of content that appeared prior to the descriptor. This is the same as the @@ -320,7 +326,7 @@ def get_annotation_lines(self): return self._annotation_lines - def _check_constraints(self, entries): + def _check_constraints(self, entries: ENTRY_TYPE) -> None: """ Does a basic check that the entries conform to this descriptor type's constraints. 
@@ -341,5 +347,5 @@ def _check_constraints(self, entries): if 'onion-key' != list(entries.keys())[0]: raise ValueError("Microdescriptor must start with a 'onion-key' entry") - def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'microdescriptors' if is_plural else 'microdescriptor' diff --git a/stem/descriptor/networkstatus.py b/stem/descriptor/networkstatus.py index 77c6d6127..6c0f5e8f4 100644 --- a/stem/descriptor/networkstatus.py +++ b/stem/descriptor/networkstatus.py @@ -65,7 +65,10 @@ import stem.util.tor_tools import stem.version +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union + from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, DigestHash, @@ -293,7 +296,7 @@ class DocumentDigest(collections.namedtuple('DocumentDigest', ['flavor', 'algori """ -def _parse_file(document_file, document_type = None, validate = False, is_microdescriptor = False, document_handler = DocumentHandler.ENTRIES, **kwargs): +def _parse_file(document_file: BinaryIO, document_type: Optional[Type] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> Iterator[Union['stem.descriptor.networkstatus.NetworkStatusDocument', 'stem.descriptor.router_status_entry.RouterStatusEntry']]: """ Parses a network status and iterates over the RouterStatusEntry in it. 
The document that these instances reference have an empty 'routers' attribute to @@ -322,6 +325,8 @@ def _parse_file(document_file, document_type = None, validate = False, is_microd if document_type is None: document_type = NetworkStatusDocumentV3 + router_type = None # type: Optional[Type[stem.descriptor.router_status_entry.RouterStatusEntry]] + if document_type == NetworkStatusDocumentV2: document_type, router_type = NetworkStatusDocumentV2, RouterStatusEntryV2 elif document_type == NetworkStatusDocumentV3: @@ -332,10 +337,10 @@ def _parse_file(document_file, document_type = None, validate = False, is_microd yield document_type(document_file.read(), validate, **kwargs) return else: - raise ValueError("Document type %i isn't recognized (only able to parse v2, v3, and bridge)" % document_type) + raise ValueError("Document type %s isn't recognized (only able to parse v2, v3, and bridge)" % document_type) if document_handler == DocumentHandler.DOCUMENT: - yield document_type(document_file.read(), validate, **kwargs) + yield document_type(document_file.read(), validate, **kwargs) # type: ignore return # getting the document without the routers section @@ -353,7 +358,7 @@ def _parse_file(document_file, document_type = None, validate = False, is_microd document_content = bytes.join(b'', header + footer) if document_handler == DocumentHandler.BARE_DOCUMENT: - yield document_type(document_content, validate, **kwargs) + yield document_type(document_content, validate, **kwargs) # type: ignore elif document_handler == DocumentHandler.ENTRIES: desc_iterator = stem.descriptor.router_status_entry._parse_file( document_file, @@ -372,7 +377,7 @@ def _parse_file(document_file, document_type = None, validate = False, is_microd raise ValueError('Unrecognized document_handler: %s' % document_handler) -def _parse_file_key_certs(certificate_file, validate = False): +def _parse_file_key_certs(certificate_file: BinaryIO, validate: bool = False) -> 
Iterator['stem.descriptor.networkstatus.KeyCertificate']: """ Parses a file containing one or more authority key certificates. @@ -401,7 +406,7 @@ def _parse_file_key_certs(certificate_file, validate = False): break # done parsing file -def _parse_file_detached_sigs(detached_signature_file, validate = False): +def _parse_file_detached_sigs(detached_signature_file: BinaryIO, validate: bool = False) -> Iterator['stem.descriptor.networkstatus.DetachedSignature']: """ Parses a file containing one or more detached signatures. @@ -431,7 +436,7 @@ class NetworkStatusDocument(Descriptor): Common parent for network status documents. """ - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this descriptor's content. These are referenced by... @@ -458,8 +463,8 @@ def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): raise NotImplementedError('Network status document digests are only available in sha1 and sha256, not %s' % hash_type) -def _parse_version_line(keyword, attribute, expected_version): - def _parse(descriptor, entries): +def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) if not value.isdigit(): @@ -473,7 +478,7 @@ def _parse(descriptor, entries): return _parse -def _parse_dir_source_line(descriptor, entries): +def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('dir-source', entries) dir_source_comp = value.split() @@ -493,7 +498,7 @@ def _parse_dir_source_line(descriptor, entries): descriptor.dir_port = None if dir_source_comp[2] == '0' else 
int(dir_source_comp[2]) -def _parse_additional_digests(descriptor, entries): +def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: digests = [] for val in _values('additional-digest', entries): @@ -507,7 +512,7 @@ def _parse_additional_digests(descriptor, entries): descriptor.additional_digests = digests -def _parse_additional_signatures(descriptor, entries): +def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: signatures = [] for val, block_type, block_contents in entries['additional-signature']: @@ -582,7 +587,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): 'signing_authority': (None, _parse_directory_signature_line), 'signatures': (None, _parse_directory_signature_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] PARSER_FOR_LINE = { 'network-status-version': _parse_network_status_version_line, @@ -598,7 +603,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): } @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('network-status-version', '2'), ('dir-source', '%s %s 80' % (_random_ipv4_address(), _random_ipv4_address())), @@ -610,7 +615,7 @@ def content(cls, attr = None, exclude = ()): ('directory-signature', 'moria2' + _random_crypto_blob('SIGNATURE')), )) - def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(NetworkStatusDocumentV2, self).__init__(raw_content, lazy_load = not validate) # Splitting the document from the routers. 
Unlike v3 documents we're not @@ -646,7 +651,7 @@ def __init__(self, raw_content, validate = False): else: self._entries = entries - def _check_constraints(self, entries): + def _check_constraints(self, entries: ENTRY_TYPE) -> None: required_fields = [field for (field, is_mandatory) in NETWORK_STATUS_V2_FIELDS if is_mandatory] for keyword in required_fields: if keyword not in entries: @@ -662,7 +667,7 @@ def _check_constraints(self, entries): raise ValueError("Network status document (v2) are expected to start with a 'network-status-version' line:\n%s" % str(self)) -def _parse_header_network_status_version_line(descriptor, entries): +def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "network-status-version" version value = _value('network-status-version', entries) @@ -683,7 +688,7 @@ def _parse_header_network_status_version_line(descriptor, entries): raise ValueError("Expected a version 3 network status document, got version '%s' instead" % descriptor.version) -def _parse_header_vote_status_line(descriptor, entries): +def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "vote-status" type # # The consensus-method and consensus-methods fields are optional since @@ -700,7 +705,7 @@ def _parse_header_vote_status_line(descriptor, entries): raise ValueError("A network status document's vote-status line can only be 'consensus' or 'vote', got '%s' instead" % value) -def _parse_header_consensus_methods_line(descriptor, entries): +def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "consensus-methods" IntegerList if descriptor._lazy_loading and descriptor.is_vote: @@ -717,7 +722,7 @@ def _parse_header_consensus_methods_line(descriptor, entries): descriptor.consensus_methods = consensus_methods -def _parse_header_consensus_method_line(descriptor, entries): +def 
_parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "consensus-method" Integer if descriptor._lazy_loading and descriptor.is_consensus: @@ -731,7 +736,7 @@ def _parse_header_consensus_method_line(descriptor, entries): descriptor.consensus_method = int(value) -def _parse_header_voting_delay_line(descriptor, entries): +def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "voting-delay" VoteSeconds DistSeconds value = _value('voting-delay', entries) @@ -744,8 +749,8 @@ def _parse_header_voting_delay_line(descriptor, entries): raise ValueError("A network status document's 'voting-delay' line must be a pair of integer values, but was '%s'" % value) -def _parse_versions_line(keyword, attribute): - def _parse(descriptor, entries): +def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, entries = _value(keyword, entries), [] for entry in value.split(','): @@ -759,7 +764,7 @@ def _parse(descriptor, entries): return _parse -def _parse_header_flag_thresholds_line(descriptor, entries): +def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "flag-thresholds" SP THRESHOLDS value, thresholds = _value('flag-thresholds', entries).strip(), {} @@ -782,7 +787,7 @@ def _parse_header_flag_thresholds_line(descriptor, entries): descriptor.flag_thresholds = thresholds -def _parse_header_parameters_line(descriptor, entries): +def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "params" [Parameters] # Parameter ::= Keyword '=' Int32 # Int32 ::= A decimal integer between -2147483648 and 2147483647. 
@@ -798,7 +803,7 @@ def _parse_header_parameters_line(descriptor, entries): descriptor._check_params_constraints() -def _parse_directory_footer_line(descriptor, entries): +def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # nothing to parse, simply checking that we don't have a value value = _value('directory-footer', entries) @@ -807,7 +812,7 @@ def _parse_directory_footer_line(descriptor, entries): raise ValueError("A network status document's 'directory-footer' line shouldn't have any content, got 'directory-footer %s'" % value) -def _parse_footer_directory_signature_line(descriptor, entries): +def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: signatures = [] for sig_value, block_type, block_contents in entries['directory-signature']: @@ -828,7 +833,7 @@ def _parse_footer_directory_signature_line(descriptor, entries): descriptor.signatures = signatures -def _parse_package_line(descriptor, entries): +def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: package_versions = [] for value, _, _ in entries['package']: @@ -849,7 +854,7 @@ def _parse_package_line(descriptor, entries): descriptor.packages = package_versions -def _parsed_shared_rand_commit(descriptor, entries): +def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "shared-rand-commit" Version AlgName Identity Commit [Reveal] commitments = [] @@ -871,7 +876,7 @@ def _parsed_shared_rand_commit(descriptor, entries): descriptor.shared_randomness_commitments = commitments -def _parse_shared_rand_previous_value(descriptor, entries): +def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "shared-rand-previous-value" NumReveals Value value = _value('shared-rand-previous-value', entries) @@ -884,7 +889,7 @@ def 
_parse_shared_rand_previous_value(descriptor, entries): raise ValueError("A network status document's 'shared-rand-previous-value' line must be a pair of values, the first an integer but was '%s'" % value) -def _parse_shared_rand_current_value(descriptor, entries): +def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "shared-rand-current-value" NumReveals Value value = _value('shared-rand-current-value', entries) @@ -897,7 +902,7 @@ def _parse_shared_rand_current_value(descriptor, entries): raise ValueError("A network status document's 'shared-rand-current-value' line must be a pair of values, the first an integer but was '%s'" % value) -def _parse_bandwidth_file_headers(descriptor, entries): +def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "bandwidth-file-headers" KeyValues # KeyValues ::= "" | KeyValue | KeyValues SP KeyValue # KeyValue ::= Keyword '=' Value @@ -912,7 +917,7 @@ def _parse_bandwidth_file_headers(descriptor, entries): descriptor.bandwidth_file_headers = results -def _parse_bandwidth_file_digest(descriptor, entries): +def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "bandwidth-file-digest" 1*(SP algorithm "=" digest) value = _value('bandwidth-file-digest', entries) @@ -1096,7 +1101,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): } @classmethod - def content(cls, attr = None, exclude = (), authorities = None, routers = None): + def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> bytes: attr = {} if attr is None else dict(attr) is_vote = attr.get('vote-status') == 'vote' @@ -1168,10 
+1173,10 @@ def content(cls, attr = None, exclude = (), authorities = None, routers = None): return desc_content @classmethod - def create(cls, attr = None, exclude = (), validate = True, authorities = None, routers = None): + def create(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3': return cls(cls.content(attr, exclude, authorities, routers), validate = validate) - def __init__(self, raw_content, validate = False, default_params = True): + def __init__(self, raw_content: bytes, validate: bool = False, default_params: bool = True) -> None: """ Parse a v3 network status document. @@ -1186,13 +1191,15 @@ def __init__(self, raw_content, validate = False, default_params = True): super(NetworkStatusDocumentV3, self).__init__(raw_content, lazy_load = not validate) document_file = io.BytesIO(raw_content) + self._header_entries = None # type: Optional[ENTRY_TYPE] + self._default_params = default_params self._header(document_file, validate) self.directory_authorities = tuple(stem.descriptor.router_status_entry._parse_file( document_file, validate, - entry_class = DirectoryAuthority, + entry_class = DirectoryAuthority, # type: ignore # TODO: move to another parse_file() entry_keyword = AUTH_START, section_end_keywords = (ROUTERS_START, FOOTER_START, V2_FOOTER_START), extra_args = (self.is_vote,), @@ -1213,7 +1220,7 @@ def __init__(self, raw_content, validate = False, default_params = True): self.routers = dict((desc.fingerprint, desc) for desc in router_iter) self._footer(document_file, validate) - def type_annotation(self): + def type_annotation(self) -> 'stem.descriptor.TypeAnnotation': if isinstance(self, 
BridgeNetworkStatusDocument): return TypeAnnotation('bridge-network-status', 1, 0) elif not self.is_microdescriptor: @@ -1225,7 +1232,7 @@ def type_annotation(self): return TypeAnnotation('network-status-microdesc-consensus-3', 1, 0) - def is_valid(self): + def is_valid(self) -> bool: """ Checks if the current time is between this document's **valid_after** and **valid_until** timestamps. To be valid means the information within this @@ -1239,7 +1246,7 @@ def is_valid(self): return self.valid_after < datetime.datetime.utcnow() < self.valid_until - def is_fresh(self): + def is_fresh(self) -> bool: """ Checks if the current time is between this document's **valid_after** and **fresh_until** timestamps. To be fresh means this should be the latest @@ -1253,13 +1260,13 @@ def is_fresh(self): return self.valid_after < datetime.datetime.utcnow() < self.fresh_until - def validate_signatures(self, key_certs): + def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificate']) -> None: """ Validates we're properly signed by the signing certificates. .. versionadded:: 1.6.0 - :param list key_certs: :class:`~stem.descriptor.networkstatus.KeyCertificates` + :param list key_certs: :class:`~stem.descriptor.networkstatus.KeyCertificate` to validate the consensus against :raises: **ValueError** if an insufficient number of valid signatures are present. 
@@ -1287,7 +1294,7 @@ def validate_signatures(self, key_certs): if valid_digests < required_digests: raise ValueError('Network Status Document has %i valid signatures out of %i total, needed %i' % (valid_digests, total_digests, required_digests)) - def get_unrecognized_lines(self): + def get_unrecognized_lines(self) -> List[str]: if self._lazy_loading: self._parse(self._header_entries, False, parser_for_line = self._HEADER_PARSER_FOR_LINE) self._parse(self._footer_entries, False, parser_for_line = self._FOOTER_PARSER_FOR_LINE) @@ -1295,7 +1302,7 @@ def get_unrecognized_lines(self): return super(NetworkStatusDocumentV3, self).get_unrecognized_lines() - def meets_consensus_method(self, method): + def meets_consensus_method(self, method: int) -> bool: """ Checks if we meet the given consensus-method. This works for both votes and consensuses, checking our 'consensus-method' and 'consensus-methods' @@ -1306,14 +1313,14 @@ def meets_consensus_method(self, method): :returns: **True** if we meet the given consensus-method, and **False** otherwise """ - if self.consensus_method is not None: - return self.consensus_method >= method - elif self.consensus_methods is not None: - return bool([x for x in self.consensus_methods if x >= method]) + if self.consensus_method is not None: # type: ignore + return self.consensus_method >= method # type: ignore + elif self.consensus_methods is not None: # type: ignore + return bool([x for x in self.consensus_methods if x >= method]) # type: ignore else: return False # malformed document - def _header(self, document_file, validate): + def _header(self, document_file: BinaryIO, validate: bool) -> None: content = bytes.join(b'', _read_until_keywords((AUTH_START, ROUTERS_START, FOOTER_START), document_file)) entries = _descriptor_components(content, validate) header_fields = [attr[0] for attr in HEADER_STATUS_DOCUMENT_FIELDS] @@ -1339,15 +1346,15 @@ def _header(self, document_file, validate): # default consensus_method and consensus_methods 
based on if we're a consensus or vote - if self.is_consensus and not self.consensus_method: + if self.is_consensus and not self.consensus_method: # type: ignore self.consensus_method = 1 - elif self.is_vote and not self.consensus_methods: + elif self.is_vote and not self.consensus_methods: # type: ignore self.consensus_methods = [1] else: self._header_entries = entries self._entries.update(entries) - def _footer(self, document_file, validate): + def _footer(self, document_file: BinaryIO, validate: bool) -> None: entries = _descriptor_components(document_file.read(), validate) footer_fields = [attr[0] for attr in FOOTER_STATUS_DOCUMENT_FIELDS] @@ -1379,7 +1386,7 @@ def _footer(self, document_file, validate): self._footer_entries = entries self._entries.update(entries) - def _check_params_constraints(self): + def _check_params_constraints(self) -> None: """ Checks that the params we know about are within their documented ranges. """ @@ -1398,7 +1405,7 @@ def _check_params_constraints(self): raise ValueError("'%s' value on the params line must be in the range of %i - %i, was %i" % (key, minimum, maximum, value)) -def _check_for_missing_and_disallowed_fields(document, entries, fields): +def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: ENTRY_TYPE, fields: Sequence[Tuple[str, bool, bool, bool]]) -> None: """ Checks that we have mandatory fields for our type, and that we don't have any fields exclusive to the other (ie, no vote-only fields appear in a @@ -1431,12 +1438,13 @@ def _check_for_missing_and_disallowed_fields(document, entries, fields): raise ValueError("Network status document has fields that shouldn't appear in this document type or version: %s" % ', '.join(disallowed_fields)) -def _parse_int_mappings(keyword, value, validate): +def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, int]: # Parse a series of 'key=value' entries, checking the following: # - values 
are integers # - keys are sorted in lexical order - results, seen_keys = {}, [] + results = {} # type: Dict[str, int] + seen_keys = [] # type: List[str] error_template = "Unable to parse network status document's '%s' line (%%s): %s'" % (keyword, value) for key, val in _mappings_for(keyword, value): @@ -1461,7 +1469,7 @@ def _parse_int_mappings(keyword, value, validate): return results -def _parse_dirauth_source_line(descriptor, entries): +def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "dir-source" nickname identity address IP dirport orport value = _value('dir-source', entries) @@ -1580,7 +1588,7 @@ class DirectoryAuthority(Descriptor): } @classmethod - def content(cls, attr = None, exclude = (), is_vote = False): + def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> bytes: attr = {} if attr is None else dict(attr) # include mandatory 'vote-digest' if a consensus @@ -1599,10 +1607,10 @@ def content(cls, attr = None, exclude = (), is_vote = False): return content @classmethod - def create(cls, attr = None, exclude = (), validate = True, is_vote = False): + def create(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, is_vote: bool = False) -> 'stem.descriptor.networkstatus.DirectoryAuthority': return cls(cls.content(attr, exclude, is_vote), validate = validate, is_vote = is_vote) - def __init__(self, raw_content, validate = False, is_vote = False): + def __init__(self, raw_content: bytes, validate: bool = False, is_vote: bool = False) -> None: """ Parse a directory authority entry in a v3 network status document. 
@@ -1621,12 +1629,12 @@ def __init__(self, raw_content, validate = False, is_vote = False): key_div = content.find('\ndir-key-certificate-version') if key_div != -1: - self.key_certificate = KeyCertificate(content[key_div + 1:], validate) + self.key_certificate = KeyCertificate(content[key_div + 1:].encode('utf-8'), validate) content = content[:key_div + 1] else: self.key_certificate = None - entries = _descriptor_components(content, validate) + entries = _descriptor_components(content.encode('utf-8'), validate) if validate and 'dir-source' != list(entries.keys())[0]: raise ValueError("Authority entries are expected to start with a 'dir-source' line:\n%s" % (content)) @@ -1677,7 +1685,7 @@ def __init__(self, raw_content, validate = False, is_vote = False): self._entries = entries -def _parse_dir_address_line(descriptor, entries): +def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "dir-address" IPPort value = _value('dir-address', entries) @@ -1752,7 +1760,7 @@ class KeyCertificate(Descriptor): } @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('dir-key-certificate-version', '3'), ('fingerprint', _random_fingerprint()), @@ -1764,26 +1772,26 @@ def content(cls, attr = None, exclude = ()): ('dir-key-certification', _random_crypto_blob('SIGNATURE')), )) - def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(KeyCertificate, self).__init__(raw_content, lazy_load = not validate) entries = _descriptor_components(raw_content, validate) if validate: if 'dir-key-certificate-version' != list(entries.keys())[0]: - raise ValueError("Key certificates must start with a 'dir-key-certificate-version' line:\n%s" % (raw_content)) + raise 
ValueError("Key certificates must start with a 'dir-key-certificate-version' line:\n%s" % stem.util.str_tools._to_unicode(raw_content)) elif 'dir-key-certification' != list(entries.keys())[-1]: - raise ValueError("Key certificates must end with a 'dir-key-certification' line:\n%s" % (raw_content)) + raise ValueError("Key certificates must end with a 'dir-key-certification' line:\n%s" % stem.util.str_tools._to_unicode(raw_content)) # check that we have mandatory fields and that our known fields only # appear once for keyword, is_mandatory in KEY_CERTIFICATE_PARAMS: if is_mandatory and keyword not in entries: - raise ValueError("Key certificates must have a '%s' line:\n%s" % (keyword, raw_content)) + raise ValueError("Key certificates must have a '%s' line:\n%s" % (keyword, stem.util.str_tools._to_unicode(raw_content))) entry_count = len(entries.get(keyword, [])) if entry_count > 1: - raise ValueError("Key certificates can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, raw_content)) + raise ValueError("Key certificates can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, stem.util.str_tools._to_unicode(raw_content))) self._parse(entries, validate) else: @@ -1805,7 +1813,7 @@ class DocumentSignature(object): :raises: **ValueError** if a validity check fails """ - def __init__(self, method, identity, key_digest, signature, flavor = None, validate = False): + def __init__(self, method: str, identity: str, key_digest: str, signature: str, flavor: Optional[str] = None, validate: bool = False) -> None: # Checking that these attributes are valid. Technically the key # digest isn't a fingerprint, but it has the same characteristics. 
@@ -1822,7 +1830,7 @@ def __init__(self, method, identity, key_digest, signature, flavor = None, valid self.signature = signature self.flavor = flavor - def _compare(self, other, method): + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool: if not isinstance(other, DocumentSignature): return False @@ -1832,19 +1840,19 @@ def _compare(self, other, method): return method(True, True) # we're equal - def __hash__(self): + def __hash__(self) -> int: return hash(str(self).strip()) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s == o) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s < o) - def __le__(self, other): + def __le__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s <= o) @@ -1885,7 +1893,7 @@ class DetachedSignature(Descriptor): 'additional_digests': ([], _parse_additional_digests), 'additional_signatures': ([], _parse_additional_signatures), 'signatures': ([], _parse_footer_directory_signature_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] PARSER_FOR_LINE = { 'consensus-digest': _parse_consensus_digest_line, @@ -1898,7 +1906,7 @@ class DetachedSignature(Descriptor): } @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('consensus-digest', '6D3CC0EFA408F228410A4A8145E1B0BB0670E442'), ('valid-after', _random_date()), @@ -1906,23 +1914,23 @@ def content(cls, attr = None, exclude = ()): ('valid-until', _random_date()), )) - def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: 
super(DetachedSignature, self).__init__(raw_content, lazy_load = not validate) entries = _descriptor_components(raw_content, validate) if validate: if 'consensus-digest' != list(entries.keys())[0]: - raise ValueError("Detached signatures must start with a 'consensus-digest' line:\n%s" % (raw_content)) + raise ValueError("Detached signatures must start with a 'consensus-digest' line:\n%s" % stem.util.str_tools._to_unicode(raw_content)) # check that we have mandatory fields and certain fields only appear once for keyword, is_mandatory, is_multiple in DETACHED_SIGNATURE_PARAMS: if is_mandatory and keyword not in entries: - raise ValueError("Detached signatures must have a '%s' line:\n%s" % (keyword, raw_content)) + raise ValueError("Detached signatures must have a '%s' line:\n%s" % (keyword, stem.util.str_tools._to_unicode(raw_content))) entry_count = len(entries.get(keyword, [])) if not is_multiple and entry_count > 1: - raise ValueError("Detached signatures can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, raw_content)) + raise ValueError("Detached signatures can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, stem.util.str_tools._to_unicode(raw_content))) self._parse(entries, validate) else: @@ -1941,7 +1949,7 @@ class BridgeNetworkStatusDocument(NetworkStatusDocument): TYPE_ANNOTATION_NAME = 'bridge-network-status' - def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(BridgeNetworkStatusDocument, self).__init__(raw_content) self.published = None diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index 24eb7b9b2..2e2bb53bb 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -101,6 +101,7 @@ from stem.descriptor import Compression from stem.util import log, str_tools +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union # Tor has a limited number of descriptors we can fetch explicitly 
by their # fingerprint or hashes due to a limit on the url length by squid proxies. @@ -121,7 +122,7 @@ DIR_PORT_BLACKLIST = ('tor26', 'Serge') -def get_instance(): +def get_instance() -> 'stem.descriptor.remote.DescriptorDownloader': """ Provides the singleton :class:`~stem.descriptor.remote.DescriptorDownloader` used for this module's shorthand functions. @@ -139,7 +140,7 @@ def get_instance(): return SINGLETON_DOWNLOADER -def their_server_descriptor(**query_args): +def their_server_descriptor(**query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the server descriptor of the relay we're downloading from. @@ -154,7 +155,7 @@ def their_server_descriptor(**query_args): return get_instance().their_server_descriptor(**query_args) -def get_server_descriptors(fingerprints = None, **query_args): +def get_server_descriptors(fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_server_descriptors` @@ -166,7 +167,7 @@ def get_server_descriptors(fingerprints = None, **query_args): return get_instance().get_server_descriptors(fingerprints, **query_args) -def get_extrainfo_descriptors(fingerprints = None, **query_args): +def get_extrainfo_descriptors(fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_extrainfo_descriptors` @@ -178,7 +179,7 @@ def get_extrainfo_descriptors(fingerprints = None, **query_args): return get_instance().get_extrainfo_descriptors(fingerprints, **query_args) -def get_microdescriptors(hashes, **query_args): +def get_microdescriptors(hashes: Optional[Union[str, Sequence[str]]], **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_microdescriptors` @@ -190,7 +191,7 @@ def get_microdescriptors(hashes, 
**query_args): return get_instance().get_microdescriptors(hashes, **query_args) -def get_consensus(authority_v3ident = None, microdescriptor = False, **query_args): +def get_consensus(authority_v3ident: Optional[str] = None, microdescriptor: bool = False, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_consensus` @@ -202,7 +203,7 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args) -def get_bandwidth_file(**query_args): +def get_bandwidth_file(**query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_bandwidth_file` @@ -214,7 +215,7 @@ def get_bandwidth_file(**query_args): return get_instance().get_bandwidth_file(**query_args) -def get_detached_signatures(**query_args): +def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_detached_signatures` @@ -370,7 +371,7 @@ class Query(object): the same as running **query.run(True)** (default is **False**) """ - def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs): + def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = 
stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: if not resource.startswith('/'): raise ValueError("Resources should start with a '/': %s" % resource) @@ -379,8 +380,10 @@ def __init__(self, resource, descriptor_type = None, endpoints = None, compressi resource = resource[:-2] elif isinstance(compression, tuple): compression = list(compression) - elif not isinstance(compression, list): + elif isinstance(compression, stem.descriptor._Compression): compression = [compression] # caller provided only a single option + else: + raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__)) if Compression.ZSTD in compression and not Compression.ZSTD.available: compression.remove(Compression.ZSTD) @@ -410,21 +413,21 @@ def __init__(self, resource, descriptor_type = None, endpoints = None, compressi self.retries = retries self.fall_back_to_authority = fall_back_to_authority - self.content = None - self.error = None + self.content = None # type: Optional[bytes] + self.error = None # type: Optional[BaseException] self.is_done = False - self.download_url = None + self.download_url = None # type: Optional[str] - self.start_time = None + self.start_time = None # type: Optional[float] self.timeout = timeout - self.runtime = None + self.runtime = None # type: Optional[float] self.validate = validate self.document_handler = document_handler - self.reply_headers = None + self.reply_headers = None # type: Optional[Dict[str, str]] self.kwargs = kwargs - self._downloader_thread = None + self._downloader_thread = None # type: Optional[threading.Thread] self._downloader_thread_lock = threading.RLock() if start: @@ -433,7 +436,7 @@ def __init__(self, resource, descriptor_type = None, endpoints = None, compressi if block: self.run(True) - def start(self): + def start(self) -> None: """ Starts downloading the scriptors if we haven't started already. 
""" @@ -449,7 +452,7 @@ def start(self): self._downloader_thread.setDaemon(True) self._downloader_thread.start() - def run(self, suppress = False): + def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so. @@ -469,7 +472,7 @@ def run(self, suppress = False): return list(self._run(suppress)) - def _run(self, suppress): + def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]: with self._downloader_thread_lock: self.start() self._downloader_thread.join() @@ -505,11 +508,11 @@ def _run(self, suppress): raise self.error - def __iter__(self): + def __iter__(self) -> Iterator[stem.descriptor.Descriptor]: for desc in self._run(True): yield desc - def _pick_endpoint(self, use_authority = False): + def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint: """ Provides an endpoint to query. If we have multiple endpoints then one is picked at random. 
@@ -527,7 +530,7 @@ def _pick_endpoint(self, use_authority = False): else: return random.choice(self.endpoints) - def _download_descriptors(self, retries, timeout): + def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None: try: self.start_time = time.time() endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority) @@ -572,10 +575,10 @@ class DescriptorDownloader(object): :class:`~stem.descriptor.remote.Query` constructor """ - def __init__(self, use_mirrors = False, **default_args): + def __init__(self, use_mirrors: bool = False, **default_args: Any) -> None: self._default_args = default_args - self._endpoints = None + self._endpoints = None # type: Optional[List[stem.DirPort]] if use_mirrors: try: @@ -585,7 +588,7 @@ def __init__(self, use_mirrors = False, **default_args): except Exception as exc: log.debug('Unable to retrieve directory mirrors: %s' % exc) - def use_directory_mirrors(self): + def use_directory_mirrors(self) -> stem.descriptor.networkstatus.NetworkStatusDocumentV3: """ Downloads the present consensus and configures ourselves to use directory mirrors, in addition to authorities. @@ -609,9 +612,9 @@ def use_directory_mirrors(self): self._endpoints = list(new_endpoints) - return consensus + return consensus # type: ignore - def their_server_descriptor(self, **query_args): + def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the server descriptor of the relay we're downloading from. @@ -625,7 +628,7 @@ def their_server_descriptor(self, **query_args): return self.query('/tor/server/authority', **query_args) - def get_server_descriptors(self, fingerprints = None, **query_args): + def get_server_descriptors(self, fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the server descriptors with the given fingerprints. 
If no fingerprints are provided then this returns all descriptors known @@ -655,7 +658,7 @@ def get_server_descriptors(self, fingerprints = None, **query_args): return self.query(resource, **query_args) - def get_extrainfo_descriptors(self, fingerprints = None, **query_args): + def get_extrainfo_descriptors(self, fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the extrainfo descriptors with the given fingerprints. If no fingerprints are provided then this returns all descriptors in the present @@ -685,7 +688,7 @@ def get_extrainfo_descriptors(self, fingerprints = None, **query_args): return self.query(resource, **query_args) - def get_microdescriptors(self, hashes, **query_args): + def get_microdescriptors(self, hashes: Optional[Union[str, Sequence[str]]], **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the microdescriptors with the given hashes. To get these see the **microdescriptor_digest** attribute of @@ -731,7 +734,7 @@ def get_microdescriptors(self, hashes, **query_args): return self.query('/tor/micro/d/%s' % '-'.join(hashes), **query_args) - def get_consensus(self, authority_v3ident = None, microdescriptor = False, **query_args): + def get_consensus(self, authority_v3ident: Optional[str] = None, microdescriptor: bool = False, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the present router status entries. @@ -775,7 +778,7 @@ def get_consensus(self, authority_v3ident = None, microdescriptor = False, **que return consensus_query - def get_vote(self, authority, **query_args): + def get_vote(self, authority: stem.directory.Authority, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the present vote for a given directory authority. 
@@ -794,13 +797,13 @@ def get_vote(self, authority, **query_args): return self.query(resource, **query_args) - def get_key_certificates(self, authority_v3idents = None, **query_args): + def get_key_certificates(self, authority_v3idents: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the key certificates for authorities with the given fingerprints. If no fingerprints are provided then this returns all present key certificates. - :param str authority_v3idents: fingerprint or list of fingerprints of the + :param str,list authority_v3idents: fingerprint or list of fingerprints of the authority keys, see `'v3ident' in tor's config.c `_ for the values. @@ -827,7 +830,7 @@ def get_key_certificates(self, authority_v3idents = None, **query_args): return self.query(resource, **query_args) - def get_bandwidth_file(self, **query_args): + def get_bandwidth_file(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the bandwidth authority heuristics used to make the next consensus. @@ -843,7 +846,7 @@ def get_bandwidth_file(self, **query_args): return self.query('/tor/status-vote/next/bandwidth', **query_args) - def get_detached_signatures(self, **query_args): + def get_detached_signatures(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the detached signatures that will be used to make the next consensus. Please note that **these are only available during minutes 55-60 @@ -896,7 +899,7 @@ def get_detached_signatures(self, **query_args): return self.query('/tor/status-vote/next/consensus-signatures', **query_args) - def query(self, resource, **query_args): + def query(self, resource: str, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Issues a request for the given resource. 
@@ -923,7 +926,7 @@ def query(self, resource, **query_args): return Query(resource, **args) -def _download_from_orport(endpoint, compression, resource): +def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given orport. Payload is just like an http response (headers and all)... @@ -973,7 +976,7 @@ def _download_from_orport(endpoint, compression, resource): for line in str_tools._to_unicode(header_data).splitlines(): if ': ' not in line: - raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % line) + raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8'))) key, value = line.split(': ', 1) headers[key] = value @@ -981,7 +984,7 @@ def _download_from_orport(endpoint, compression, resource): return _decompress(body_data, headers.get('Content-Encoding')), headers -def _download_from_dirport(url, compression, timeout): +def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given url. @@ -1010,13 +1013,13 @@ def _download_from_dirport(url, compression, timeout): except socket.timeout as exc: raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout) except: - exc, stacktrace = sys.exc_info()[1:3] - raise stem.DownloadFailed(url, exc, stacktrace) + exception, stacktrace = sys.exc_info()[1:3] + raise stem.DownloadFailed(url, exception, stacktrace) return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers -def _decompress(data, encoding): +def _decompress(data: bytes, encoding: str) -> bytes: """ Decompresses descriptor data. 
@@ -1030,6 +1033,8 @@ def _decompress(data, encoding): :param bytes data: data we received :param str encoding: 'Content-Encoding' header of the response + :returns: **bytes** with the decompressed data + :raises: * **ValueError** if encoding is unrecognized * **ImportError** if missing the decompression module @@ -1045,7 +1050,7 @@ def _decompress(data, encoding): raise ValueError("'%s' isn't a recognized type of encoding" % encoding) -def _guess_descriptor_type(resource): +def _guess_descriptor_type(resource: str) -> str: # Attempts to determine the descriptor type based on the resource url. This # raises a ValueError if the resource isn't recognized. diff --git a/stem/descriptor/router_status_entry.py b/stem/descriptor/router_status_entry.py index c2d8dd07b..2c4937f37 100644 --- a/stem/descriptor/router_status_entry.py +++ b/stem/descriptor/router_status_entry.py @@ -27,7 +27,10 @@ import stem.exit_policy import stem.util.str_tools +from typing import Any, BinaryIO, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union + from stem.descriptor import ( + ENTRY_TYPE, KEYWORD_LINE, Descriptor, _descriptor_content, @@ -35,7 +38,7 @@ _values, _descriptor_components, _parse_protocol_line, - _read_until_keywords, + _read_until_keywords_with_ending_keyword, _random_nickname, _random_ipv4_address, _random_date, @@ -44,7 +47,7 @@ _parse_pr_line = _parse_protocol_line('pr', 'protocols') -def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start_position = None, end_position = None, section_end_keywords = (), extra_args = ()): +def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: Optional[int] = None, end_position: Optional[int] = None, section_end_keywords: Tuple[str, ...] 
= (), extra_args: Sequence[Any] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: """ Reads a range of the document_file containing some number of entry_class instances. We deliminate the entry_class entries by the keyword on their @@ -91,7 +94,7 @@ def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start return while end_position is None or document_file.tell() < end_position: - desc_lines, ending_keyword = _read_until_keywords( + desc_lines, ending_keyword = _read_until_keywords_with_ending_keyword( (entry_keyword,) + section_end_keywords, document_file, ignore_first = True, @@ -111,7 +114,7 @@ def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start break -def _parse_r_line(descriptor, entries): +def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # Parses a RouterStatusEntry's 'r' line. They're very nearly identical for # all current entry types (v2, v3, and microdescriptor v3) with one little # wrinkle: only the microdescriptor flavor excludes a 'digest' field. 
@@ -163,7 +166,7 @@ def _parse_r_line(descriptor, entries): raise ValueError("Publication time time wasn't parsable: r %s" % value) -def _parse_a_line(descriptor, entries): +def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "a" SP address ":" portlist # example: a [2001:888:2133:0:82:94:251:204]:9001 @@ -186,7 +189,7 @@ def _parse_a_line(descriptor, entries): descriptor.or_addresses = or_addresses -def _parse_s_line(descriptor, entries): +def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "s" Flags # example: s Named Running Stable Valid @@ -201,7 +204,7 @@ def _parse_s_line(descriptor, entries): raise ValueError("%s had extra whitespace on its 's' line: s %s" % (descriptor._name(), value)) -def _parse_v_line(descriptor, entries): +def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "v" version # example: v Tor 0.2.2.35 # @@ -219,7 +222,7 @@ def _parse_v_line(descriptor, entries): raise ValueError('%s has a malformed tor version (%s): v %s' % (descriptor._name(), exc, value)) -def _parse_w_line(descriptor, entries): +def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "w" "Bandwidth=" INT ["Measured=" INT] ["Unmeasured=1"] # example: w Bandwidth=7980 @@ -266,7 +269,7 @@ def _parse_w_line(descriptor, entries): descriptor.unrecognized_bandwidth_entries = unrecognized_bandwidth_entries -def _parse_p_line(descriptor, entries): +def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "p" ("accept" / "reject") PortList # # examples: @@ -282,7 +285,7 @@ def _parse_p_line(descriptor, entries): raise ValueError('%s exit policy is malformed (%s): p %s' % (descriptor._name(), exc, value)) -def _parse_id_line(descriptor, entries): +def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "id" "ed25519" ed25519-identity # # examples: @@ -305,7 
+308,7 @@ def _parse_id_line(descriptor, entries): raise ValueError("'id' lines should contain both the key type and digest: id %s" % value) -def _parse_m_line(descriptor, entries): +def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "m" methods 1*(algorithm "=" digest) # example: m 8,9,10,11,12 sha256=g1vx9si329muxV3tquWIXXySNOIwRGMeAESKs/v4DWs @@ -339,14 +342,14 @@ def _parse_m_line(descriptor, entries): descriptor.microdescriptor_hashes = all_hashes -def _parse_microdescriptor_m_line(descriptor, entries): +def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "m" digest # example: m aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70 descriptor.microdescriptor_digest = _value('m', entries) -def _base64_to_hex(identity, check_if_fingerprint = True): +def _base64_to_hex(identity: str, check_if_fingerprint: bool = True) -> str: """ Decodes a base64 value to hex. For example... @@ -420,7 +423,7 @@ class RouterStatusEntry(Descriptor): } @classmethod - def from_str(cls, content, **kwargs): + def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> Union['stem.descriptor.router_status_entry.RouterStatusEntry', List['stem.descriptor.router_status_entry.RouterStatusEntry']]: # type: ignore # Router status entries don't have their own @type annotation, so to make # our subclass from_str() work we need to do the type inferencing ourself. @@ -440,14 +443,14 @@ def from_str(cls, content, **kwargs): else: raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." 
% len(results)) - def __init__(self, content, validate = False, document = None): + def __init__(self, content: bytes, validate: bool = False, document: Optional['stem.descriptor.networkstatus.NetworkStatusDocument'] = None) -> None: """ Parse a router descriptor in a network status document. :param str content: router descriptor content to be parsed - :param NetworkStatusDocument document: document this descriptor came from :param bool validate: checks the validity of the content if **True**, skips these checks otherwise + :param NetworkStatusDocument document: document this descriptor came from :raises: **ValueError** if the descriptor data is invalid """ @@ -472,21 +475,21 @@ def __init__(self, content, validate = False, document = None): else: self._entries = entries - def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: """ Name for this descriptor type. """ return 'Router status entries' if is_plural else 'Router status entry' - def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: """ Provides lines that must appear in the descriptor. """ return () - def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: """ Provides lines that can only appear in the descriptor once. 
""" @@ -512,18 +515,18 @@ class RouterStatusEntryV2(RouterStatusEntry): }) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())), )) - def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'Router status entries (v2)' if is_plural else 'Router status entry (v2)' - def _required_fields(self): - return ('r') + def _required_fields(self) -> Tuple[str, ...]: + return ('r',) - def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v') @@ -577,7 +580,7 @@ class RouterStatusEntryV3(RouterStatusEntry): TYPE_ANNOTATION_NAME = 'network-status-consensus-3' - ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ + ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ # type: ignore 'digest': (None, _parse_r_line), 'or_addresses': ([], _parse_a_line), 'identifier_type': (None, _parse_id_line), @@ -593,7 +596,7 @@ class RouterStatusEntryV3(RouterStatusEntry): 'microdescriptor_hashes': ([], _parse_m_line), }) - PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ + PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ # type: ignore 'a': _parse_a_line, 'w': _parse_w_line, 'p': _parse_p_line, @@ -603,19 +606,19 @@ class RouterStatusEntryV3(RouterStatusEntry): }) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), 
_random_date(), _random_ipv4_address())), ('s', 'Fast Named Running Stable Valid'), )) - def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'Router status entries (v3)' if is_plural else 'Router status entry (v3)' - def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return ('r', 's') - def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v', 'w', 'p', 'pr') @@ -650,7 +653,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): TYPE_ANNOTATION_NAME = 'network-status-microdesc-consensus-3' - ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ + ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ # type: ignore 'or_addresses': ([], _parse_a_line), 'bandwidth': (None, _parse_w_line), 'measured': (None, _parse_w_line), @@ -660,7 +663,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): 'microdescriptor_digest': (None, _parse_microdescriptor_m_line), }) - PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ + PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ # type: ignore 'a': _parse_a_line, 'w': _parse_w_line, 'm': _parse_microdescriptor_m_line, @@ -668,18 +671,18 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): }) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('r', '%s ARIJF2zbqirB9IwsW0mQznccWww %s %s 9001 9030' % (_random_nickname(), _random_date(), _random_ipv4_address())), ('m', 'aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70'), ('s', 'Fast Guard HSDir Named Running Stable V2Dir Valid'), )) - def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'Router status entries (micro v3)' if is_plural else 'Router status entry (micro v3)' - def _required_fields(self): + def 
_required_fields(self) -> Tuple[str, ...]: return ('r', 's', 'm') - def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v', 'w', 'm', 'pr') diff --git a/stem/descriptor/server_descriptor.py b/stem/descriptor/server_descriptor.py index 955b84299..fbb5c633b 100644 --- a/stem/descriptor/server_descriptor.py +++ b/stem/descriptor/server_descriptor.py @@ -61,15 +61,17 @@ from stem.descriptor.certificate import Ed25519Certificate from stem.descriptor.router_status_entry import RouterStatusEntryV3 +from typing import Any, BinaryIO, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, DigestHash, DigestEncoding, create_signing_key, _descriptor_content, - _descriptor_components, + _descriptor_components_with_extra, _read_until_keywords, _bytes_for_block, _value, @@ -139,11 +141,11 @@ DEFAULT_BRIDGE_DISTRIBUTION = 'any' -def _truncated_b64encode(content): +def _truncated_b64encode(content: bytes) -> str: return stem.util.str_tools._to_unicode(base64.b64encode(content).rstrip(b'=')) -def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.server_descriptor.ServerDescriptor']: """ Iterates over the server descriptors in a file. 
@@ -213,14 +215,17 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): descriptor_text = bytes.join(b'', descriptor_content) if is_bridge: - yield BridgeDescriptor(descriptor_text, validate, **kwargs) + if kwargs: + raise ValueError('BUG: keyword arguments unused by bridge descriptors') + + yield BridgeDescriptor(descriptor_text, validate) else: yield RelayDescriptor(descriptor_text, validate, **kwargs) else: break # done parsing descriptors -def _parse_router_line(descriptor, entries): +def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "router" nickname address ORPort SocksPort DirPort value = _value('router', entries) @@ -246,7 +251,7 @@ def _parse_router_line(descriptor, entries): descriptor.dir_port = None if router_comp[4] == '0' else int(router_comp[4]) -def _parse_bandwidth_line(descriptor, entries): +def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "bandwidth" bandwidth-avg bandwidth-burst bandwidth-observed value = _value('bandwidth', entries) @@ -266,7 +271,7 @@ def _parse_bandwidth_line(descriptor, entries): descriptor.observed_bandwidth = int(bandwidth_comp[2]) -def _parse_platform_line(descriptor, entries): +def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "platform" string _parse_bytes_line('platform', 'platform')(descriptor, entries) @@ -292,7 +297,7 @@ def _parse_platform_line(descriptor, entries): pass -def _parse_fingerprint_line(descriptor, entries): +def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # This is forty hex digits split into space separated groups of four. # Checking that we match this pattern. 
@@ -309,7 +314,7 @@ def _parse_fingerprint_line(descriptor, entries): descriptor.fingerprint = fingerprint -def _parse_extrainfo_digest_line(descriptor, entries): +def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('extra-info-digest', entries) digest_comp = value.split(' ') @@ -320,7 +325,7 @@ def _parse_extrainfo_digest_line(descriptor, entries): descriptor.extra_info_sha256_digest = digest_comp[1] if len(digest_comp) >= 2 else None -def _parse_hibernating_line(descriptor, entries): +def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "hibernating" 0|1 (in practice only set if one) value = _value('hibernating', entries) @@ -331,7 +336,7 @@ def _parse_hibernating_line(descriptor, entries): descriptor.hibernating = value == '1' -def _parse_protocols_line(descriptor, entries): +def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('protocols', entries) protocols_match = re.match('^Link (.*) Circuit (.*)$', value) @@ -343,7 +348,7 @@ def _parse_protocols_line(descriptor, entries): descriptor.circuit_protocols = circuit_versions.split(' ') -def _parse_or_address_line(descriptor, entries): +def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: all_values = _values('or-address', entries) or_addresses = [] @@ -366,7 +371,7 @@ def _parse_or_address_line(descriptor, entries): descriptor.or_addresses = or_addresses -def _parse_history_line(keyword, history_end_attribute, history_interval_attribute, history_values_attribute, descriptor, entries): +def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) timestamp, interval, remainder = 
stem.descriptor.extrainfo_descriptor._parse_timestamp_and_interval(keyword, value) @@ -383,7 +388,7 @@ def _parse_history_line(keyword, history_end_attribute, history_interval_attribu setattr(descriptor, history_values_attribute, history_values) -def _parse_exit_policy(descriptor, entries): +def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: if hasattr(descriptor, '_unparsed_exit_policy'): if descriptor._unparsed_exit_policy and stem.util.str_tools._to_unicode(descriptor._unparsed_exit_policy[0]) == 'reject *:*': descriptor.exit_policy = REJECT_ALL_POLICY @@ -576,7 +581,7 @@ class ServerDescriptor(Descriptor): 'eventdns': _parse_eventdns_line, } - def __init__(self, raw_contents, validate = False): + def __init__(self, raw_contents: bytes, validate: bool = False) -> None: """ Server descriptor constructor, created from an individual relay's descriptor content (as provided by 'GETINFO desc/*', cached descriptors, @@ -603,7 +608,7 @@ def __init__(self, raw_contents, validate = False): # influences the resulting exit policy, but for everything else the order # does not matter so breaking it into key / value pairs. 
- entries, self._unparsed_exit_policy = _descriptor_components(stem.util.str_tools._to_unicode(raw_contents), validate, extra_keywords = ('accept', 'reject'), non_ascii_fields = ('contact', 'platform')) + entries, self._unparsed_exit_policy = _descriptor_components_with_extra(raw_contents, validate, extra_keywords = ('accept', 'reject'), non_ascii_fields = ('contact', 'platform')) if validate: self._parse(entries, validate) @@ -621,7 +626,7 @@ def __init__(self, raw_contents, validate = False): else: self._entries = entries - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this descriptor's content. These are referenced by... @@ -641,7 +646,7 @@ def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): raise NotImplementedError('Unsupported Operation: this should be implemented by the ServerDescriptor subclass') - def _check_constraints(self, entries): + def _check_constraints(self, entries: ENTRY_TYPE) -> None: """ Does a basic check that the entries conform to this descriptor type's constraints. @@ -679,16 +684,16 @@ def _check_constraints(self, entries): # Constraints that the descriptor must meet to be valid. These can be None if # not applicable. 
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS - def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS + SINGLE_FIELDS - def _first_keyword(self): + def _first_keyword(self) -> str: return 'router' - def _last_keyword(self): + def _last_keyword(self) -> Optional[str]: return 'router-signature' @@ -753,7 +758,7 @@ class RelayDescriptor(ServerDescriptor): 'router-signature': _parse_router_signature_line, }) - def __init__(self, raw_contents, validate = False, skip_crypto_validation = False): + def __init__(self, raw_contents: bytes, validate: bool = False, skip_crypto_validation: bool = False) -> None: super(RelayDescriptor, self).__init__(raw_contents, validate) if validate: @@ -785,9 +790,8 @@ def __init__(self, raw_contents, validate = False, skip_crypto_validation = Fals pass # cryptography module unavailable @classmethod - def content(cls, attr = None, exclude = (), sign = False, signing_key = None, exit_policy = None): - if attr is None: - attr = {} + def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> bytes: + attr = dict(attr) if attr else {} if exit_policy is None: exit_policy = REJECT_ALL_POLICY @@ -797,7 +801,7 @@ def content(cls, attr = None, exclude = (), sign = False, signing_key = None, ex ('published', _random_date()), ('bandwidth', '153600 256000 104590'), ] + [ - tuple(line.split(' ', 1)) for line in str(exit_policy).splitlines() + tuple(line.split(' ', 1)) for line in str(exit_policy).splitlines() # type: ignore ] + [ ('onion-key', _random_crypto_blob('RSA PUBLIC KEY')), ('signing-key', _random_crypto_blob('RSA PUBLIC KEY')), @@ -827,15 +831,18 @@ def content(cls, attr = None, exclude = (), sign = False, signing_key = 
None, ex )) @classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, signing_key = None, exit_policy = None): + def create(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> 'stem.descriptor.server_descriptor.RelayDescriptor': return cls(cls.content(attr, exclude, sign, signing_key, exit_policy), validate = validate, skip_crypto_validation = not sign) @functools.lru_cache() - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Provides the digest of our descriptor's content. - :returns: the digest string encoded in uppercase hex + :param stem.descriptor.DigestHash hash_type: digest hashing algorithm + :param stem.descriptor.DigestEncoding encoding: digest encoding + + :returns: **hashlib.HASH** or **str** based on our encoding argument :raises: ValueError if the digest cannot be calculated """ @@ -849,7 +856,7 @@ def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): else: raise NotImplementedError('Server descriptor digests are only available in sha1 and sha256, not %s' % hash_type) - def make_router_status_entry(self): + def make_router_status_entry(self) -> 'stem.descriptor.router_status_entry.RouterStatusEntryV3': """ Provides a RouterStatusEntryV3 for this descriptor content. 
@@ -885,15 +892,15 @@ def make_router_status_entry(self): if self.certificate: attr['id'] = 'ed25519 %s' % _truncated_b64encode(self.certificate.key) - return RouterStatusEntryV3.create(attr) + return RouterStatusEntryV3.create(attr) # type: ignore @functools.lru_cache() - def _onion_key_crosscert_digest(self): + def _onion_key_crosscert_digest(self) -> str: """ Provides the digest of the onion-key-crosscert data. This consists of the RSA identity key sha1 and ed25519 identity key. - :returns: **unicode** digest encoded in uppercase hex + :returns: **str** digest encoded in uppercase hex :raises: ValueError if the digest cannot be calculated """ @@ -902,7 +909,7 @@ def _onion_key_crosscert_digest(self): data = signing_key_digest + base64.b64decode(stem.util.str_tools._to_bytes(self.ed25519_master_key) + b'=') return stem.util.str_tools._to_unicode(binascii.hexlify(data).upper()) - def _check_constraints(self, entries): + def _check_constraints(self, entries: ENTRY_TYPE) -> None: super(RelayDescriptor, self)._check_constraints(entries) if self.certificate: @@ -941,7 +948,7 @@ class BridgeDescriptor(ServerDescriptor): }) @classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('router', '%s %s 9001 0 0' % (_random_nickname(), _random_ipv4_address())), ('router-digest', '006FD96BA35E7785A6A3B8B75FE2E2435A13BDB4'), @@ -950,13 +957,13 @@ def content(cls, attr = None, exclude = ()): ('reject', '*:*'), )) - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX: return self._digest else: raise 
NotImplementedError('Bridge server descriptor digests are only available as sha1/hex, not %s/%s' % (hash_type, encoding)) - def is_scrubbed(self): + def is_scrubbed(self) -> bool: """ Checks if we've been properly scrubbed in accordance with the `bridge descriptor specification @@ -969,7 +976,7 @@ def is_scrubbed(self): return self.get_scrubbing_issues() == [] @functools.lru_cache() - def get_scrubbing_issues(self): + def get_scrubbing_issues(self) -> Sequence[str]: """ Provides issues with our scrubbing. @@ -1003,7 +1010,7 @@ def get_scrubbing_issues(self): return issues - def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: # bridge required fields are the same as a relay descriptor, minus items # excluded according to the format page @@ -1019,8 +1026,8 @@ def _required_fields(self): return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields]) - def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return self._required_fields() + SINGLE_FIELDS - def _last_keyword(self): + def _last_keyword(self) -> Optional[str]: return None diff --git a/stem/descriptor/tordnsel.py b/stem/descriptor/tordnsel.py index d0f57b939..6b9d4cebc 100644 --- a/stem/descriptor/tordnsel.py +++ b/stem/descriptor/tordnsel.py @@ -10,18 +10,23 @@ TorDNSEL - Exit list provided by TorDNSEL """ +import datetime + import stem.util.connection import stem.util.str_tools import stem.util.tor_tools +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Tuple + from stem.descriptor import ( + ENTRY_TYPE, Descriptor, _read_until_keywords, _descriptor_components, ) -def _parse_file(tordnsel_file, validate = False, **kwargs): +def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']: """ Iterates over a tordnsel file. 
@@ -33,6 +38,9 @@ def _parse_file(tordnsel_file, validate = False, **kwargs): * **IOError** if the file can't be read """ + if kwargs: + raise ValueError("TorDNSEL doesn't support additional arguments: %s" % kwargs) + # skip content prior to the first ExitNode _read_until_keywords('ExitNode', tordnsel_file, skip = True) @@ -41,7 +49,7 @@ def _parse_file(tordnsel_file, validate = False, **kwargs): contents += _read_until_keywords('ExitNode', tordnsel_file) if contents: - yield TorDNSEL(bytes.join(b'', contents), validate, **kwargs) + yield TorDNSEL(bytes.join(b'', contents), validate) else: break # done parsing file @@ -62,19 +70,20 @@ class TorDNSEL(Descriptor): TYPE_ANNOTATION_NAME = 'tordnsel' - def __init__(self, raw_contents, validate): + def __init__(self, raw_contents: bytes, validate: bool) -> None: super(TorDNSEL, self).__init__(raw_contents) - raw_contents = stem.util.str_tools._to_unicode(raw_contents) entries = _descriptor_components(raw_contents, validate) - self.fingerprint = None - self.published = None - self.last_status = None - self.exit_addresses = [] + self.fingerprint = None # type: Optional[str] + self.published = None # type: Optional[datetime.datetime] + self.last_status = None # type: Optional[datetime.datetime] + self.exit_addresses = [] # type: List[Tuple[str, datetime.datetime]] self._parse(entries, validate) - def _parse(self, entries, validate): + def _parse(self, entries: ENTRY_TYPE, validate: bool, parser_for_line: Optional[Dict[str, Callable]] = None) -> None: + if parser_for_line: + raise ValueError('parser_for_line is unused by TorDNSEL') for keyword, values in list(entries.items()): value, block_type, block_content = values[0] @@ -101,7 +110,7 @@ def _parse(self, entries, validate): raise ValueError("LastStatus time wasn't parsable: %s" % value) elif keyword == 'ExitAddress': for value, block_type, block_content in values: - address, date = value.split(' ', 1) + address, date_str = value.split(' ', 1) if validate: if not 
stem.util.connection.is_valid_ipv4_address(address): @@ -110,7 +119,7 @@ def _parse(self, entries, validate): raise ValueError('Unexpected block content: %s' % block_content) try: - date = stem.util.str_tools._parse_timestamp(date) + date = stem.util.str_tools._parse_timestamp(date_str) self.exit_addresses.append((address, date)) except ValueError: if validate: diff --git a/stem/directory.py b/stem/directory.py index 67079c804..3ecb0b710 100644 --- a/stem/directory.py +++ b/stem/directory.py @@ -49,6 +49,7 @@ import stem.util.conf from stem.util import connection, str_tools, tor_tools +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Pattern, Sequence, Tuple, Union GITWEB_AUTHORITY_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/auth_dirs.inc' GITWEB_FALLBACK_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/fallback_dirs.inc' @@ -68,7 +69,7 @@ FALLBACK_IPV6 = re.compile('" ipv6=\\[([\\da-f:]+)\\]:(\\d+)"') -def _match_with(lines, regexes, required = None): +def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Dict[Pattern, Union[str, List[str]]]: """ Scans the given content against a series of regex matchers, providing back a mapping of regexes to their capture groups. 
This maping is with the value if @@ -101,7 +102,7 @@ def _match_with(lines, regexes, required = None): return matches -def _directory_entries(lines, pop_section_func, regexes, required = None): +def _directory_entries(lines: List[str], pop_section_func: Callable[[List[str]], List[str]], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Iterator[Dict[Pattern, Union[str, List[str]]]]: next_section = pop_section_func(lines) while next_section: @@ -129,11 +130,11 @@ class Directory(object): :var int dir_port: port on which directory information is available :var str fingerprint: relay fingerprint :var str nickname: relay nickname - :var str orport_v6: **(address, port)** tuple for the directory's IPv6 + :var tuple orport_v6: **(address, port)** tuple for the directory's IPv6 ORPort, or **None** if it doesn't have one """ - def __init__(self, address, or_port, dir_port, fingerprint, nickname, orport_v6): + def __init__(self, address: str, or_port: Union[int, str], dir_port: Union[int, str], fingerprint: str, nickname: str, orport_v6: Tuple[str, int]) -> None: identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint if not connection.is_valid_ipv4_address(address): @@ -163,7 +164,7 @@ def __init__(self, address, or_port, dir_port, fingerprint, nickname, orport_v6) self.orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None @staticmethod - def from_cache(): + def from_cache() -> Dict[str, Any]: """ Provides cached Tor directory information. This information is hardcoded into Tor and occasionally changes, so the information provided by this @@ -181,7 +182,7 @@ def from_cache(): raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass') @staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> Dict[str, Any]: """ Reads and parses tor's directory data `from gitweb.torproject.org `_. 
Note that while convenient, this reliance on GitWeb means you should alway @@ -209,13 +210,13 @@ def from_remote(timeout = 60): raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass') - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'address', 'or_port', 'dir_port', 'fingerprint', 'nickname', 'orport_v6') - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Directory) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -231,7 +232,7 @@ class Authority(Directory): :var str v3ident: identity key fingerprint used to sign votes and consensus """ - def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, orport_v6 = None, v3ident = None): + def __init__(self, address: Optional[str] = None, or_port: Optional[Union[int, str]] = None, dir_port: Optional[Union[int, str]] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[Tuple[str, int]] = None, v3ident: Optional[str] = None) -> None: super(Authority, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6) if v3ident and not tor_tools.is_valid_fingerprint(v3ident): @@ -241,11 +242,11 @@ def __init__(self, address = None, or_port = None, dir_port = None, fingerprint self.v3ident = v3ident @staticmethod - def from_cache(): + def from_cache() -> Dict[str, 'stem.directory.Authority']: return dict(DIRECTORY_AUTHORITIES) @staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Authority']: try: lines = str_tools._to_unicode(urllib.request.urlopen(GITWEB_AUTHORITY_URL, timeout = timeout).read()).splitlines() @@ -275,8 +276,8 @@ def from_remote(timeout = 60): dir_port = dir_port, fingerprint = fingerprint.replace(' ', ''), nickname = nickname, - orport_v6 = 
matches.get(AUTHORITY_IPV6), - v3ident = matches.get(AUTHORITY_V3IDENT), + orport_v6 = matches.get(AUTHORITY_IPV6), # type: ignore + v3ident = matches.get(AUTHORITY_V3IDENT), # type: ignore ) except ValueError as exc: raise IOError(str(exc)) @@ -284,7 +285,7 @@ def from_remote(timeout = 60): return results @staticmethod - def _pop_section(lines): + def _pop_section(lines: List[str]) -> List[str]: """ Provides the next authority entry. """ @@ -299,13 +300,13 @@ def _pop_section(lines): return section_lines - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'v3ident', parent = Directory, cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Authority) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -348,13 +349,13 @@ class Fallback(Directory): :var collections.OrderedDict header: metadata about the fallback directory file this originated from """ - def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, has_extrainfo = False, orport_v6 = None, header = None): + def __init__(self, address: Optional[str] = None, or_port: Optional[Union[int, str]] = None, dir_port: Optional[Union[int, str]] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[Tuple[str, int]] = None, header: Optional[Mapping[str, str]] = None) -> None: super(Fallback, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6) self.has_extrainfo = has_extrainfo self.header = collections.OrderedDict(header) if header else collections.OrderedDict() @staticmethod - def from_cache(path = FALLBACK_CACHE_PATH): + def from_cache(path: str = FALLBACK_CACHE_PATH) -> Dict[str, 'stem.directory.Fallback']: conf = stem.util.conf.Config() conf.load(path) headers = collections.OrderedDict([(k.split('.', 1)[1], 
conf.get(k)) for k in conf.keys() if k.startswith('header.')]) @@ -393,7 +394,7 @@ def from_cache(path = FALLBACK_CACHE_PATH): return results @staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Fallback']: try: lines = str_tools._to_unicode(urllib.request.urlopen(GITWEB_FALLBACK_URL, timeout = timeout).read()).splitlines() @@ -439,9 +440,9 @@ def from_remote(timeout = 60): or_port = int(or_port), dir_port = int(dir_port), fingerprint = fingerprint, - nickname = matches.get(FALLBACK_NICKNAME), + nickname = matches.get(FALLBACK_NICKNAME), # type: ignore has_extrainfo = matches.get(FALLBACK_EXTRAINFO) == '1', - orport_v6 = matches.get(FALLBACK_IPV6), + orport_v6 = matches.get(FALLBACK_IPV6), # type: ignore header = header, ) except ValueError as exc: @@ -450,7 +451,7 @@ def from_remote(timeout = 60): return results @staticmethod - def _pop_section(lines): + def _pop_section(lines: List[str]) -> List[str]: """ Provides lines up through the next divider. This excludes lines with just a comma since they're an artifact of these being C strings. @@ -470,7 +471,7 @@ def _pop_section(lines): return section_lines @staticmethod - def _write(fallbacks, tor_commit, stem_commit, headers, path = FALLBACK_CACHE_PATH): + def _write(fallbacks: Dict[str, 'stem.directory.Fallback'], tor_commit: str, stem_commit: str, headers: Mapping[str, str], path: str = FALLBACK_CACHE_PATH) -> None: """ Persists fallback directories to a location in a way that can be read by from_cache(). 
@@ -503,17 +504,17 @@ def _write(fallbacks, tor_commit, stem_commit, headers, path = FALLBACK_CACHE_PA conf.save(path) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'has_extrainfo', 'header', parent = Directory, cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Fallback) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other -def _fallback_directory_differences(previous_directories, new_directories): +def _fallback_directory_differences(previous_directories: Mapping[str, 'stem.directory.Fallback'], new_directories: Mapping[str, 'stem.directory.Fallback']) -> str: """ Provides a description of how fallback directories differ. """ diff --git a/stem/exit_policy.py b/stem/exit_policy.py index ddcd7dfd0..19178c9ac 100644 --- a/stem/exit_policy.py +++ b/stem/exit_policy.py @@ -71,6 +71,8 @@ import stem.util.enum import stem.util.str_tools +from typing import Any, Iterator, List, Optional, Sequence, Set, Union + AddressType = stem.util.enum.Enum(('WILDCARD', 'Wildcard'), ('IPv4', 'IPv4'), ('IPv6', 'IPv6')) # Addresses aliased by the 'private' policy. From the tor man page... @@ -89,7 +91,7 @@ ) -def _flag_private_rules(rules): +def _flag_private_rules(rules: Sequence['ExitPolicyRule']) -> None: """ Determine if part of our policy was expanded from the 'private' keyword. This doesn't differentiate if this actually came from the 'private' keyword or a @@ -139,7 +141,7 @@ def _flag_private_rules(rules): last_rule._is_private = True -def _flag_default_rules(rules): +def _flag_default_rules(rules: Sequence['ExitPolicyRule']) -> None: """ Determine if part of our policy ends with the defaultly appended suffix. 
""" @@ -162,9 +164,11 @@ class ExitPolicy(object): entries that make up this policy """ - def __init__(self, *rules): + def __init__(self, *rules: Union[str, 'stem.exit_policy.ExitPolicyRule']) -> None: # sanity check the types + self._input_rules = None # type: Optional[Union[bytes, Sequence[Union[str, bytes, stem.exit_policy.ExitPolicyRule]]]] + for rule in rules: if not isinstance(rule, (bytes, str)) and not isinstance(rule, ExitPolicyRule): raise TypeError('Exit policy rules can only contain strings or ExitPolicyRules, got a %s (%s)' % (type(rule), rules)) @@ -181,13 +185,14 @@ def __init__(self, *rules): is_all_str = False if rules and is_all_str: - byte_rules = [stem.util.str_tools._to_bytes(r) for r in rules] + byte_rules = [stem.util.str_tools._to_bytes(r) for r in rules] # type: ignore self._input_rules = zlib.compress(b','.join(byte_rules)) else: self._input_rules = rules - self._rules = None - self._hash = None + self._policy_str = None # type: Optional[str] + self._rules = None # type: List[stem.exit_policy.ExitPolicyRule] + self._hash = None # type: Optional[int] # Result when no rules apply. According to the spec policies default to 'is # allowed', but our microdescriptor policy subclass might want to change @@ -196,7 +201,7 @@ def __init__(self, *rules): self._is_allowed_default = True @functools.lru_cache() - def can_exit_to(self, address = None, port = None, strict = False): + def can_exit_to(self, address: Optional[str] = None, port: Optional[int] = None, strict: bool = False) -> bool: """ Checks if this policy allows exiting to a given destination or not. If the address or port is omitted then this will check if we're allowed to exit to @@ -220,13 +225,13 @@ def can_exit_to(self, address = None, port = None, strict = False): return self._is_allowed_default @functools.lru_cache() - def is_exiting_allowed(self): + def is_exiting_allowed(self) -> bool: """ Provides **True** if the policy allows exiting whatsoever, **False** otherwise. 
""" - rejected_ports = set() + rejected_ports = set() # type: Set[int] for rule in self._get_rules(): if rule.is_accept: @@ -242,7 +247,7 @@ def is_exiting_allowed(self): return self._is_allowed_default @functools.lru_cache() - def summary(self): + def summary(self) -> str: """ Provides a short description of our policy chain, similar to a microdescriptor. This excludes entries that don't cover all IP @@ -296,7 +301,8 @@ def summary(self): # convert port list to a list of ranges (ie, ['1-3'] rather than [1, 2, 3]) if display_ports: - display_ranges, temp_range = [], [] + display_ranges = [] + temp_range = [] # type: List[int] display_ports.sort() display_ports.append(None) # ending item to include last range in loop @@ -320,7 +326,7 @@ def summary(self): return (label_prefix + ', '.join(display_ranges)).strip() - def has_private(self): + def has_private(self) -> bool: """ Checks if we have any rules expanded from the 'private' keyword. Tor appends these by default to the start of the policy and includes a dynamic @@ -338,7 +344,7 @@ def has_private(self): return False - def strip_private(self): + def strip_private(self) -> 'ExitPolicy': """ Provides a copy of this policy without 'private' policy entries. @@ -349,7 +355,7 @@ def strip_private(self): return ExitPolicy(*[rule for rule in self._get_rules() if not rule.is_private()]) - def has_default(self): + def has_default(self) -> bool: """ Checks if we have the default policy suffix. @@ -364,7 +370,7 @@ def has_default(self): return False - def strip_default(self): + def strip_default(self) -> 'ExitPolicy': """ Provides a copy of this policy without the default policy suffix. @@ -375,30 +381,35 @@ def strip_default(self): return ExitPolicy(*[rule for rule in self._get_rules() if not rule.is_default()]) - def _get_rules(self): + def _get_rules(self) -> Sequence['stem.exit_policy.ExitPolicyRule']: # Local reference to our input_rules so this can be lock free. 
Otherwise # another thread might unset our input_rules while processing them. input_rules = self._input_rules if self._rules is None and input_rules is not None: - rules = [] + rules = [] # type: List[stem.exit_policy.ExitPolicyRule] is_all_accept, is_all_reject = True, True + decompressed_rules = None # type: Optional[Sequence[Union[str, bytes, stem.exit_policy.ExitPolicyRule]]] if isinstance(input_rules, bytes): decompressed_rules = zlib.decompress(input_rules).split(b',') else: decompressed_rules = input_rules - for rule in decompressed_rules: - if isinstance(rule, bytes): - rule = stem.util.str_tools._to_unicode(rule) + for rule_val in decompressed_rules: + if isinstance(rule_val, bytes): + rule_val = stem.util.str_tools._to_unicode(rule_val) - if isinstance(rule, (bytes, str)): - if not rule.strip(): + if isinstance(rule_val, str): + if not rule_val.strip(): continue - rule = ExitPolicyRule(rule.strip()) + rule = ExitPolicyRule(rule_val.strip()) + elif isinstance(rule_val, stem.exit_policy.ExitPolicyRule): + rule = rule_val + else: + raise TypeError('BUG: unexpected type within decompressed policy: %s (%s)' % (stem.util.str_tools._to_unicode(rule_val), type(rule_val).__name__)) if rule.is_accept: is_all_reject = False @@ -437,18 +448,20 @@ def _get_rules(self): return self._rules - def __len__(self): + def __len__(self) -> int: return len(self._get_rules()) - def __iter__(self): + def __iter__(self) -> Iterator['stem.exit_policy.ExitPolicyRule']: for rule in self._get_rules(): yield rule - @functools.lru_cache() - def __str__(self): - return ', '.join([str(rule) for rule in self._get_rules()]) + def __str__(self) -> str: + if self._policy_str is None: + self._policy_str = ', '.join([str(rule) for rule in self._get_rules()]) + + return self._policy_str - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: my_hash = 0 @@ -460,10 +473,10 @@ def __hash__(self): return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: 
return hash(self) == hash(other) if isinstance(other, ExitPolicy) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -495,7 +508,7 @@ class MicroExitPolicy(ExitPolicy): :param str policy: policy string that describes this policy """ - def __init__(self, policy): + def __init__(self, policy: str) -> None: # Microdescriptor policies are of the form... # # MicrodescriptrPolicy ::= ("accept" / "reject") SP PortList NL @@ -503,7 +516,7 @@ def __init__(self, policy): # PortList ::= PortList "," PortOrRange # PortOrRange ::= INT "-" INT / INT - self._policy = policy + policy_str = policy if policy.startswith('accept'): self.is_accept = True @@ -515,7 +528,7 @@ def __init__(self, policy): policy = policy[6:] if not policy.startswith(' '): - raise ValueError('A microdescriptor exit policy should have a space separating accept/reject from its port list: %s' % self._policy) + raise ValueError('A microdescriptor exit policy should have a space separating accept/reject from its port list: %s' % policy_str) policy = policy.lstrip() @@ -536,17 +549,18 @@ def __init__(self, policy): super(MicroExitPolicy, self).__init__(*rules) self._is_allowed_default = not self.is_accept + self._policy_str = policy_str - def __str__(self): - return self._policy + def __str__(self) -> str: + return self._policy_str - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, MicroExitPolicy) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -580,7 +594,7 @@ class ExitPolicyRule(object): :raises: **ValueError** if input isn't a valid tor exit policy rule """ - def __init__(self, rule): + def __init__(self, rule: str) -> None: # policy ::= "accept[6]" exitpattern | "reject[6]" exitpattern # exitpattern ::= addrspec ":" portspec @@ -604,17 +618,17 @@ def 
__init__(self, rule): if ':' not in exitpattern or ']' in exitpattern.rsplit(':', 1)[1]: raise ValueError("An exitpattern must be of the form 'addrspec:portspec': %s" % rule) - self.address = None - self._address_type = None - self._masked_bits = None - self.min_port = self.max_port = None - self._hash = None + self.address = None # type: Optional[str] + self._address_type = None # type: Optional[stem.exit_policy.AddressType] + self._masked_bits = None # type: Optional[int] + self.min_port = self.max_port = None # type: Optional[int] + self._hash = None # type: Optional[int] # Our mask in ip notation (ex. '255.255.255.0'). This is only set if we # either have a custom mask that can't be represented by a number of bits, # or the user has called mask(), lazily loading this. - self._mask = None + self._mask = None # type: Optional[str] # Malformed exit policies are rejected, but there's an exception where it's # just skipped: when an accept6/reject6 rule has an IPv4 address... @@ -634,7 +648,7 @@ def __init__(self, rule): self._is_private = False self._is_default_suffix = False - def is_address_wildcard(self): + def is_address_wildcard(self) -> bool: """ **True** if we'll match against **any** address, **False** otherwise. @@ -646,7 +660,7 @@ def is_address_wildcard(self): return self._address_type == _address_type_to_int(AddressType.WILDCARD) - def is_port_wildcard(self): + def is_port_wildcard(self) -> bool: """ **True** if we'll match against any port, **False** otherwise. @@ -655,7 +669,7 @@ def is_port_wildcard(self): return self.min_port in (0, 1) and self.max_port == 65535 - def is_match(self, address = None, port = None, strict = False): + def is_match(self, address: Optional[str] = None, port: Optional[int] = None, strict: bool = False) -> bool: """ **True** if we match against the given destination, **False** otherwise. 
If the address or port is omitted then this will check if we're allowed to @@ -726,7 +740,7 @@ def is_match(self, address = None, port = None, strict = False): else: return True - def get_address_type(self): + def get_address_type(self) -> AddressType: """ Provides the :data:`~stem.exit_policy.AddressType` for our policy. @@ -735,7 +749,7 @@ def get_address_type(self): return _int_to_address_type(self._address_type) - def get_mask(self, cache = True): + def get_mask(self, cache: bool = True) -> str: """ Provides the address represented by our mask. This is **None** if our address type is a wildcard. @@ -765,7 +779,7 @@ def get_mask(self, cache = True): return self._mask - def get_masked_bits(self): + def get_masked_bits(self) -> int: """ Provides the number of bits our subnet mask represents. This is **None** if our mask can't have a bit representation. @@ -775,7 +789,7 @@ def get_masked_bits(self): return self._masked_bits - def is_private(self): + def is_private(self) -> bool: """ Checks if this rule was expanded from the 'private' policy keyword. @@ -786,7 +800,7 @@ def is_private(self): return self._is_private - def is_default(self): + def is_default(self) -> bool: """ Checks if this rule belongs to the default exit policy suffix. @@ -798,7 +812,7 @@ def is_default(self): return self._is_default_suffix @functools.lru_cache() - def __str__(self): + def __str__(self) -> str: """ Provides the string representation of our policy. 
This does not necessarily match the rule that we were constructed from (due to things @@ -842,18 +856,18 @@ def __str__(self): return label @functools.lru_cache() - def _get_mask_bin(self): + def _get_mask_bin(self) -> int: # provides an integer representation of our mask return int(stem.util.connection._address_to_binary(self.get_mask(False)), 2) @functools.lru_cache() - def _get_address_bin(self): + def _get_address_bin(self) -> int: # provides an integer representation of our address return stem.util.connection.address_to_int(self.address) & self._get_mask_bin() - def _apply_addrspec(self, rule, addrspec, is_ipv6_only): + def _apply_addrspec(self, rule: str, addrspec: str, is_ipv6_only: bool) -> None: # Parses the addrspec... # addrspec ::= "*" | ip4spec | ip6spec @@ -924,7 +938,7 @@ def _apply_addrspec(self, rule, addrspec, is_ipv6_only): else: raise ValueError("'%s' isn't a wildcard, IPv4, or IPv6 address: %s" % (addrspec, rule)) - def _apply_portspec(self, rule, portspec): + def _apply_portspec(self, rule: str, portspec: str) -> None: # Parses the portspec... # portspec ::= "*" | port | port "-" port # port ::= an integer between 1 and 65535, inclusive. 
@@ -955,24 +969,24 @@ def _apply_portspec(self, rule, portspec): else: raise ValueError("Port value isn't a wildcard, integer, or range: %s" % rule) - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = stem.util._hash_attr(self, 'is_accept', 'address', 'min_port', 'max_port') * 1024 + hash(self.get_mask(False)) return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ExitPolicyRule) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other -def _address_type_to_int(address_type): +def _address_type_to_int(address_type: AddressType) -> int: return AddressType.index_of(address_type) -def _int_to_address_type(address_type_int): +def _int_to_address_type(address_type_int: int) -> AddressType: return list(AddressType)[address_type_int] @@ -981,32 +995,32 @@ class MicroExitPolicyRule(ExitPolicyRule): Lighter weight ExitPolicyRule derivative for microdescriptors. 
""" - def __init__(self, is_accept, min_port, max_port): + def __init__(self, is_accept: bool, min_port: int, max_port: int) -> None: self.is_accept = is_accept self.address = None # wildcard address self.min_port = min_port self.max_port = max_port self._skip_rule = False - def is_address_wildcard(self): + def is_address_wildcard(self) -> bool: return True - def get_address_type(self): + def get_address_type(self) -> AddressType: return AddressType.WILDCARD - def get_mask(self, cache = True): + def get_mask(self, cache = True) -> str: return None - def get_masked_bits(self): + def get_masked_bits(self) -> int: return None - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'is_accept', 'min_port', 'max_port', cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, MicroExitPolicyRule) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py index 07a5f5736..1d08abb6a 100644 --- a/stem/interpreter/__init__.py +++ b/stem/interpreter/__init__.py @@ -38,11 +38,11 @@ @uses_settings -def msg(message, config, **attr): +def msg(message: str, config: 'stem.util.conf.Config', **attr: str) -> str: return config.get(message).format(**attr) -def main(): +def main() -> None: try: import readline except ImportError: @@ -54,13 +54,13 @@ def main(): import stem.interpreter.commands try: - args = stem.interpreter.arguments.parse(sys.argv[1:]) + args = stem.interpreter.arguments.Arguments.parse(sys.argv[1:]) except ValueError as exc: print(exc) sys.exit(1) if args.print_help: - print(stem.interpreter.arguments.get_help()) + print(stem.interpreter.arguments.Arguments.get_help()) sys.exit() if args.disable_color or not sys.stdout.isatty(): @@ -82,13 +82,11 @@ def main(): if not args.run_cmd and not args.run_path: print(format(msg('msg.starting_tor'), 
*HEADER_OUTPUT)) - control_port = '9051' if args.control_port == 'default' else str(args.control_port) - try: stem.process.launch_tor_with_config( config = { 'SocksPort': '0', - 'ControlPort': control_port, + 'ControlPort': '9051' if args.control_port is None else str(args.control_port), 'CookieAuthentication': '1', 'ExitPolicy': 'reject *:*', }, @@ -115,7 +113,7 @@ def main(): control_port = control_port, control_socket = control_socket, password_prompt = True, - ) + ) # type: stem.control.Controller if controller is None: sys.exit(1) @@ -126,7 +124,7 @@ def main(): if args.run_cmd: if args.run_cmd.upper().startswith('SETEVENTS '): - controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT)) + controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT)) # type: ignore if sys.stdout.isatty(): events = args.run_cmd.upper().split(' ', 1)[1] @@ -135,7 +133,7 @@ def main(): controller.msg(args.run_cmd) try: - raw_input() + input() except (KeyboardInterrupt, stem.SocketClosed): pass else: diff --git a/stem/interpreter/arguments.py b/stem/interpreter/arguments.py index 00c8891de..dd0b19bbd 100644 --- a/stem/interpreter/arguments.py +++ b/stem/interpreter/arguments.py @@ -5,101 +5,102 @@ Commandline argument parsing for our interpreter prompt. """ -import collections import getopt import os import stem.interpreter import stem.util.connection -DEFAULT_ARGS = { - 'control_address': '127.0.0.1', - 'control_port': 'default', - 'user_provided_port': False, - 'control_socket': '/var/run/tor/control', - 'user_provided_socket': False, - 'tor_path': 'tor', - 'run_cmd': None, - 'run_path': None, - 'disable_color': False, - 'print_help': False, -} +from typing import Any, Dict, NamedTuple, Optional, Sequence OPT = 'i:s:h' OPT_EXPANDED = ['interface=', 'socket=', 'tor=', 'run=', 'no-color', 'help'] -def parse(argv): - """ - Parses our arguments, providing a named tuple with their values. 
- - :param list argv: input arguments to be parsed - - :returns: a **named tuple** with our parsed arguments - - :raises: **ValueError** if we got an invalid argument - """ - - args = dict(DEFAULT_ARGS) - - try: - recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) - - if unrecognized_args: - error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" - raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) - except Exception as exc: - raise ValueError('%s (for usage provide --help)' % exc) - - for opt, arg in recognized_args: - if opt in ('-i', '--interface'): - if ':' in arg: - address, port = arg.rsplit(':', 1) - else: - address, port = None, arg - - if address is not None: - if not stem.util.connection.is_valid_ipv4_address(address): - raise ValueError("'%s' isn't a valid IPv4 address" % address) - - args['control_address'] = address - - if not stem.util.connection.is_valid_port(port): - raise ValueError("'%s' isn't a valid port number" % port) - - args['control_port'] = int(port) - args['user_provided_port'] = True - elif opt in ('-s', '--socket'): - args['control_socket'] = arg - args['user_provided_socket'] = True - elif opt in ('--tor'): - args['tor_path'] = arg - elif opt in ('--run'): - if os.path.exists(arg): - args['run_path'] = arg - else: - args['run_cmd'] = arg - elif opt == '--no-color': - args['disable_color'] = True - elif opt in ('-h', '--help'): - args['print_help'] = True - - # translates our args dict into a named tuple - - Args = collections.namedtuple('Args', args.keys()) - return Args(**args) - - -def get_help(): - """ - Provides our --help usage information. 
- - :returns: **str** with our usage information - """ - - return stem.interpreter.msg( - 'msg.help', - address = DEFAULT_ARGS['control_address'], - port = DEFAULT_ARGS['control_port'], - socket = DEFAULT_ARGS['control_socket'], - ) +class Arguments(NamedTuple): + control_address: str = '127.0.0.1' + control_port: Optional[int] = None + user_provided_port: bool = False + control_socket: str = '/var/run/tor/control' + user_provided_socket: bool = False + tor_path: str = 'tor' + run_cmd: Optional[str] = None + run_path: Optional[str] = None + disable_color: bool = False + print_help: bool = False + + @staticmethod + def parse(argv: Sequence[str]) -> 'stem.interpreter.arguments.Arguments': + """ + Parses our commandline arguments into this class. + + :param list argv: input arguments to be parsed + + :returns: :class:`stem.interpreter.arguments.Arguments` for this + commandline input + + :raises: **ValueError** if we got an invalid argument + """ + + args = {} # type: Dict[str, Any] + + try: + recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) # type: ignore + + if unrecognized_args: + error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" + raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) + except Exception as exc: + raise ValueError('%s (for usage provide --help)' % exc) + + for opt, arg in recognized_args: + if opt in ('-i', '--interface'): + if ':' in arg: + address, port = arg.rsplit(':', 1) + else: + address, port = None, arg + + if address is not None: + if not stem.util.connection.is_valid_ipv4_address(address): + raise ValueError("'%s' isn't a valid IPv4 address" % address) + + args['control_address'] = address + + if not stem.util.connection.is_valid_port(port): + raise ValueError("'%s' isn't a valid port number" % port) + + args['control_port'] = int(port) + args['user_provided_port'] = True + elif opt in ('-s', '--socket'): + args['control_socket'] 
= arg + args['user_provided_socket'] = True + elif opt in ('--tor'): + args['tor_path'] = arg + elif opt in ('--run'): + if os.path.exists(arg): + args['run_path'] = arg + else: + args['run_cmd'] = arg + elif opt == '--no-color': + args['disable_color'] = True + elif opt in ('-h', '--help'): + args['print_help'] = True + + return Arguments(**args) + + @staticmethod + def get_help() -> str: + """ + Provides our --help usage information. + + :returns: **str** with our usage information + """ + + defaults = Arguments() + + return stem.interpreter.msg( + 'msg.help', + address = defaults.control_address, + port = defaults.control_port if defaults.control_port else 'default', + socket = defaults.control_socket, + ) diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py index 9f5f2659b..e310ed283 100644 --- a/stem/interpreter/autocomplete.py +++ b/stem/interpreter/autocomplete.py @@ -7,11 +7,15 @@ import functools +import stem.control +import stem.util.conf + from stem.interpreter import uses_settings +from typing import List, Optional @uses_settings -def _get_commands(controller, config): +def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Config) -> List[str]: """ Provides commands recognized by tor. """ @@ -76,11 +80,11 @@ def _get_commands(controller, config): class Autocompleter(object): - def __init__(self, controller): + def __init__(self, controller: stem.control.Controller) -> None: self._commands = _get_commands(controller) @functools.lru_cache() - def matches(self, text): + def matches(self, text: str) -> List[str]: """ Provides autocompletion matches for the given text. 
@@ -92,7 +96,7 @@ def matches(self, text): lowercase_text = text.lower() return [cmd for cmd in self._commands if cmd.lower().startswith(lowercase_text)] - def complete(self, text, state): + def complete(self, text: str, state: int) -> Optional[str]: """ Provides case insensetive autocompletion options, acting as a functor for the readlines set_completer function. diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py index 6e61fddab..254e46a12 100644 --- a/stem/interpreter/commands.py +++ b/stem/interpreter/commands.py @@ -21,11 +21,12 @@ from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg from stem.util.term import format +from typing import Iterator, List, TextIO MAX_EVENTS = 100 -def _get_fingerprint(arg, controller): +def _get_fingerprint(arg: str, controller: stem.control.Controller) -> str: """ Resolves user input into a relay fingerprint. This accepts... @@ -90,7 +91,7 @@ def _get_fingerprint(arg, controller): @contextlib.contextmanager -def redirect(stdout, stderr): +def redirect(stdout: TextIO, stderr: TextIO) -> Iterator[None]: original = sys.stdout, sys.stderr sys.stdout, sys.stderr = stdout, stderr @@ -106,8 +107,8 @@ class ControlInterpreter(code.InteractiveConsole): for special irc style subcommands. 
""" - def __init__(self, controller): - self._received_events = [] + def __init__(self, controller: stem.control.Controller) -> None: + self._received_events = [] # type: List[stem.response.events.Event] code.InteractiveConsole.__init__(self, { 'stem': stem, @@ -129,25 +130,26 @@ def __init__(self, controller): handle_event_real = self._controller._handle_event - def handle_event_wrapper(event_message): + def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None: handle_event_real(event_message) - self._received_events.insert(0, event_message) + self._received_events.insert(0, event_message) # type: ignore if len(self._received_events) > MAX_EVENTS: self._received_events.pop() - self._controller._handle_event = handle_event_wrapper + # type check disabled due to https://github.com/python/mypy/issues/708 - def get_events(self, *event_types): + self._controller._handle_event = handle_event_wrapper # type: ignore + + def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]: events = list(self._received_events) - event_types = list(map(str.upper, event_types)) # make filtering case insensitive if event_types: events = [e for e in events if e.type in event_types] return events - def do_help(self, arg): + def do_help(self, arg: str) -> str: """ Performs the '/help' operation, giving usage information for the given argument or a general summary if there wasn't one. @@ -155,7 +157,7 @@ def do_help(self, arg): return stem.interpreter.help.response(self._controller, arg) - def do_events(self, arg): + def do_events(self, arg: str) -> str: """ Performs the '/events' operation, dumping the events that we've received belonging to the given types. 
If no types are specified then this provides @@ -173,7 +175,7 @@ def do_events(self, arg): return '\n'.join([format(str(e), *STANDARD_OUTPUT) for e in self.get_events(*event_types)]) - def do_info(self, arg): + def do_info(self, arg: str) -> str: """ Performs the '/info' operation, looking up a relay by fingerprint, IP address, or nickname and printing its descriptor and consensus entries in a @@ -271,7 +273,7 @@ def do_info(self, arg): return '\n'.join(lines) - def do_python(self, arg): + def do_python(self, arg: str) -> str: """ Performs the '/python' operation, toggling if we accept python commands or not. @@ -295,17 +297,15 @@ def do_python(self, arg): return format(response, *STANDARD_OUTPUT) @uses_settings - def run_command(self, command, config, print_response = False): + def run_command(self, command: str, config: stem.util.conf.Config, print_response: bool = False) -> str: """ Runs the given command. Requests starting with a '/' are special commands to the interpreter, and anything else is sent to the control port. 
- :param stem.control.Controller controller: tor control connection :param str command: command to be processed :param bool print_response: prints the response to stdout if true - :returns: **list** out output lines, each line being a list of - (msg, format) tuples + :returns: **str** output of the command :raises: **stem.SocketClosed** if the control connection has been severed """ @@ -363,7 +363,7 @@ def run_command(self, command, config, print_response = False): output = console_output.getvalue().strip() else: try: - output = format(self._controller.msg(command).raw_content().strip(), *STANDARD_OUTPUT) + output = format(str(self._controller.msg(command).raw_content()).strip(), *STANDARD_OUTPUT) except stem.ControllerError as exc: if isinstance(exc, stem.SocketClosed): raise diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py index 1f242a8ea..3a206c352 100644 --- a/stem/interpreter/help.py +++ b/stem/interpreter/help.py @@ -7,6 +7,11 @@ import functools +import stem.control +import stem.util.conf + +from stem.util.term import format + from stem.interpreter import ( STANDARD_OUTPUT, BOLD_OUTPUT, @@ -15,10 +20,8 @@ uses_settings, ) -from stem.util.term import format - -def response(controller, arg): +def response(controller: stem.control.Controller, arg: str) -> str: """ Provides our /help response. @@ -33,7 +36,7 @@ def response(controller, arg): return _response(controller, _normalize(arg)) -def _normalize(arg): +def _normalize(arg: str) -> str: arg = arg.upper() # If there's multiple arguments then just take the first. 
This is @@ -52,7 +55,7 @@ def _normalize(arg): @functools.lru_cache() @uses_settings -def _response(controller, arg, config): +def _response(controller: stem.control.Controller, arg: str, config: stem.util.conf.Config) -> str: if not arg: return _general_help() @@ -126,7 +129,7 @@ def _response(controller, arg, config): return output.rstrip() -def _general_help(): +def _general_help() -> str: lines = [] for line in msg('help.general').splitlines(): diff --git a/stem/manual.py b/stem/manual.py index 367b6d7eb..9bc10b852 100644 --- a/stem/manual.py +++ b/stem/manual.py @@ -61,8 +61,11 @@ import stem.util.conf import stem.util.enum import stem.util.log +import stem.util.str_tools import stem.util.system +from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union + Category = stem.util.enum.Enum('GENERAL', 'CLIENT', 'RELAY', 'DIRECTORY', 'AUTHORITY', 'HIDDEN_SERVICE', 'DENIAL_OF_SERVICE', 'TESTING', 'UNKNOWN') GITWEB_MANUAL_URL = 'https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt' CACHE_PATH = os.path.join(os.path.dirname(__file__), 'cached_manual.sqlite') @@ -103,13 +106,13 @@ class SchemaMismatch(IOError): :var tuple supported_schemas: schemas library supports """ - def __init__(self, message, database_schema, library_schema): + def __init__(self, message: str, database_schema: int, supported_schemas: Tuple[int]) -> None: super(SchemaMismatch, self).__init__(message) self.database_schema = database_schema - self.library_schema = library_schema + self.supported_schemas = supported_schemas -def query(query, *param): +def query(query: str, *param: str) -> 'sqlite3.Cursor': # type: ignore """ Performs the given query on our sqlite manual cache. This database should be treated as being read-only. 
File permissions generally enforce this, and @@ -162,25 +165,25 @@ class ConfigOption(object): :var str description: longer manual description with details """ - def __init__(self, name, category = Category.UNKNOWN, usage = '', summary = '', description = ''): + def __init__(self, name: str, category: 'stem.manual.Category' = Category.UNKNOWN, usage: str = '', summary: str = '', description: str = '') -> None: self.name = name self.category = category self.usage = usage self.summary = summary self.description = description - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'name', 'category', 'usage', 'summary', 'description', cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ConfigOption) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @functools.lru_cache() -def _config(lowercase = True): +def _config(lowercase: bool = True) -> Dict[str, Union[List[str], str]]: """ Provides a dictionary for our settings.cfg. This has a couple categories... @@ -204,7 +207,7 @@ def _config(lowercase = True): return {} -def _manual_differences(previous_manual, new_manual): +def _manual_differences(previous_manual: 'stem.manual.Manual', new_manual: 'stem.manual.Manual') -> str: """ Provides a description of how two manuals differ. """ @@ -249,7 +252,7 @@ def _manual_differences(previous_manual, new_manual): return '\n'.join(lines) -def is_important(option): +def is_important(option: str) -> bool: """ Indicates if a configuration option of particularly common importance or not. 
@@ -262,7 +265,7 @@ def is_important(option): return option.lower() in _config()['manual.important'] -def download_man_page(path = None, file_handle = None, url = GITWEB_MANUAL_URL, timeout = 20): +def download_man_page(path: Optional[str] = None, file_handle: Optional[IO[bytes]] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None: """ Downloads tor's latest man page from `gitweb.torproject.org `_. This method is @@ -301,7 +304,7 @@ def download_man_page(path = None, file_handle = None, url = GITWEB_MANUAL_URL, if not os.path.exists(manual_path): raise OSError('no man page was generated') except stem.util.system.CallError as exc: - raise IOError("Unable to run '%s': %s" % (exc.command, exc.stderr)) + raise IOError("Unable to run '%s': %s" % (exc.command, stem.util.str_tools._to_unicode(exc.stderr))) if path: try: @@ -347,7 +350,7 @@ class Manual(object): :var str stem_commit: stem commit to cache this manual information """ - def __init__(self, name, synopsis, description, commandline_options, signals, files, config_options): + def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, 'stem.manual.ConfigOption']) -> None: self.name = name self.synopsis = synopsis self.description = description @@ -360,7 +363,7 @@ def __init__(self, name, synopsis, description, commandline_options, signals, fi self.schema = None @staticmethod - def from_cache(path = None): + def from_cache(path: Optional[str] = None) -> 'stem.manual.Manual': """ Provides manual information cached with Stem. Unlike :func:`~stem.manual.Manual.from_man` and @@ -424,7 +427,7 @@ def from_cache(path = None): return manual @staticmethod - def from_man(man_path = 'tor'): + def from_man(man_path: str = 'tor') -> 'stem.manual.Manual': """ Reads and parses a given man page. 
@@ -447,7 +450,8 @@ def from_man(man_path = 'tor'): except OSError as exc: raise IOError("Unable to run '%s': %s" % (man_cmd, exc)) - categories, config_options = _get_categories(man_output), collections.OrderedDict() + categories = _get_categories(man_output) + config_options = collections.OrderedDict() # type: collections.OrderedDict[str, stem.manual.ConfigOption] for category_header, category_enum in CATEGORY_SECTIONS.items(): _add_config_options(config_options, category_enum, categories.get(category_header, [])) @@ -467,7 +471,7 @@ def from_man(man_path = 'tor'): ) @staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> 'stem.manual.Manual': """ Reads and parses the latest tor man page `from gitweb.torproject.org `_. Note that @@ -500,7 +504,7 @@ def from_remote(timeout = 60): download_man_page(file_handle = tmp, timeout = timeout) return Manual.from_man(tmp.name) - def save(self, path): + def save(self, path: str) -> None: """ Persists the manual content to a given location. @@ -549,17 +553,17 @@ def save(self, path): os.rename(tmp_path, path) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'name', 'synopsis', 'description', 'commandline_options', 'signals', 'files', 'config_options', cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Manual) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other -def _get_categories(content): +def _get_categories(content: Sequence[str]) -> Dict[str, List[str]]: """ The man page is headers followed by an indented section. First pass gets the mapping of category titles to their lines. 
@@ -574,7 +578,8 @@ def _get_categories(content): content = content[:-1] categories = collections.OrderedDict() - category, lines = None, [] + category = None + lines = [] # type: List[str] for line in content: # replace non-ascii characters @@ -605,7 +610,7 @@ def _get_categories(content): return categories -def _get_indented_descriptions(lines): +def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, str]: """ Parses the commandline argument and signal sections. These are options followed by an indented description. For example... @@ -622,7 +627,8 @@ def _get_indented_descriptions(lines): ignoring those. """ - options, last_arg = collections.OrderedDict(), None + options = collections.OrderedDict() # type: collections.OrderedDict[str, List[str]] + last_arg = None for line in lines: if line == ' Note': @@ -635,7 +641,7 @@ def _get_indented_descriptions(lines): return dict([(arg, ' '.join(desc_lines)) for arg, desc_lines in options.items() if desc_lines]) -def _add_config_options(config_options, category, lines): +def _add_config_options(config_options: Dict[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None: """ Parses a section of tor configuration options. These have usage information, followed by an indented description. For instance... @@ -653,7 +659,7 @@ def _add_config_options(config_options, category, lines): since that platform lacks getrlimit(). 
(Default: 1000) """ - def add_option(title, description): + def add_option(title: str, description: List[str]) -> None: if 'PER INSTANCE OPTIONS' in title: return # skip, unfortunately amid the options @@ -667,7 +673,7 @@ def add_option(title, description): add_option(subtitle, description) else: name, usage = title.split(' ', 1) if ' ' in title else (title, '') - summary = _config().get('manual.summary.%s' % name.lower(), '') + summary = str(_config().get('manual.summary.%s' % name.lower(), '')) config_options[name] = ConfigOption(name, category, usage, summary, _join_lines(description).strip()) # Remove the section's description by finding the sentence the section @@ -679,7 +685,8 @@ def add_option(title, description): lines = lines[max(end_indices):] # trim to the description paragrah lines = lines[lines.index(''):] # drop the paragraph - last_title, description = None, [] + last_title = None + description = [] # type: List[str] for line in lines: if line and not line.startswith(' '): @@ -697,12 +704,12 @@ def add_option(title, description): add_option(last_title, description) -def _join_lines(lines): +def _join_lines(lines: Sequence[str]) -> str: """ Simple join, except we want empty lines to still provide a newline. 
""" - result = [] + result = [] # type: List[str] for line in lines: if not line: diff --git a/stem/process.py b/stem/process.py index a1d805ec2..3c7688a57 100644 --- a/stem/process.py +++ b/stem/process.py @@ -29,11 +29,13 @@ import stem.util.system import stem.version +from typing import Any, Callable, Dict, Optional, Sequence, Union + NO_TORRC = '' DEFAULT_INIT_TIMEOUT = 90 -def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_percent = 100, init_msg_handler = None, timeout = DEFAULT_INIT_TIMEOUT, take_ownership = False, close_output = True, stdin = None): +def launch_tor(tor_cmd: str = 'tor', args: Optional[Sequence[str]] = None, torrc_path: Optional[str] = None, completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True, stdin: Optional[str] = None) -> subprocess.Popen: """ Initializes a tor process. This blocks until initialization completes or we error out. 
@@ -131,7 +133,7 @@ def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_perce tor_process.stdin.close() if timeout: - def timeout_handler(signum, frame): + def timeout_handler(signum: int, frame: Any) -> None: raise OSError('reached a %i second timeout without success' % timeout) signal.signal(signal.SIGALRM, timeout_handler) @@ -197,7 +199,7 @@ def timeout_handler(signum, frame): pass -def launch_tor_with_config(config, tor_cmd = 'tor', completion_percent = 100, init_msg_handler = None, timeout = DEFAULT_INIT_TIMEOUT, take_ownership = False, close_output = True): +def launch_tor_with_config(config: Dict[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen: """ Initializes a tor process, like :func:`~stem.process.launch_tor`, but with a customized configuration. This writes a temporary torrc to disk, launches @@ -258,7 +260,7 @@ def launch_tor_with_config(config, tor_cmd = 'tor', completion_percent = 100, in break if not has_stdout: - config['Log'].append('NOTICE stdout') + config['Log'] = list(config['Log']) + ['NOTICE stdout'] config_str = '' diff --git a/stem/response/__init__.py b/stem/response/__init__.py index 2fbb9c489..2f851389e 100644 --- a/stem/response/__init__.py +++ b/stem/response/__init__.py @@ -38,6 +38,8 @@ import stem.util import stem.util.str_tools +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union + __all__ = [ 'add_onion', 'events', @@ -54,7 +56,7 @@ KEY_ARG = re.compile('^(\\S+)=') -def convert(response_type, message, **kwargs): +def convert(response_type: str, message: 'stem.response.ControlMessage', **kwargs: Any) -> None: """ Converts a :class:`~stem.response.ControlMessage` into a particular kind of tor response. 
This does an in-place conversion of the message from being a @@ -121,7 +123,40 @@ def convert(response_type, message, **kwargs): raise TypeError('Unsupported response type: %s' % response_type) message.__class__ = response_class - message._parse_message(**kwargs) + message._parse_message(**kwargs) # type: ignore + + +# TODO: These aliases are for type hint compatibility. We should refactor how +# message conversion is performed to avoid this headache. + +def _convert_to_single_line(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.SingleLineResponse': + stem.response.convert('SINGLELINE', message) + return message # type: ignore + + +def _convert_to_event(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.events.Event': + stem.response.convert('EVENT', message) + return message # type: ignore + + +def _convert_to_getinfo(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.getinfo.GetInfoResponse': + stem.response.convert('GETINFO', message) + return message # type: ignore + + +def _convert_to_getconf(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.getconf.GetConfResponse': + stem.response.convert('GETCONF', message) + return message # type: ignore + + +def _convert_to_add_onion(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.add_onion.AddOnionResponse': + stem.response.convert('ADD_ONION', message) + return message # type: ignore + + +def _convert_to_mapaddress(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.mapaddress.MapAddressResponse': + stem.response.convert('MAPADDRESS', message) + return message # type: ignore class ControlMessage(object): @@ -140,7 +175,7 @@ class ControlMessage(object): """ @staticmethod - def from_str(content, msg_type = None, normalize = False, **kwargs): + def from_str(content: Union[str, bytes], msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 
'stem.response.ControlMessage': """ Provides a ControlMessage for the given content. @@ -158,31 +193,38 @@ def from_str(content, msg_type = None, normalize = False, **kwargs): :returns: stem.response.ControlMessage instance """ + if isinstance(content, str): + content = stem.util.str_tools._to_bytes(content) + if normalize: - if not content.endswith('\n'): - content += '\n' + if not content.endswith(b'\n'): + content += b'\n' - content = re.sub('([\r]?)\n', '\r\n', content) + content = re.sub(b'([\r]?)\n', b'\r\n', content) - msg = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes(content)), arrived_at = kwargs.pop('arrived_at', None)) + msg = stem.socket.recv_message(io.BytesIO(content), arrived_at = kwargs.pop('arrived_at', None)) if msg_type is not None: convert(msg_type, msg, **kwargs) return msg - def __init__(self, parsed_content, raw_content, arrived_at = None): + def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[float] = None) -> None: if not parsed_content: raise ValueError("ControlMessages can't be empty") - self.arrived_at = arrived_at if arrived_at else int(time.time()) + # TODO: Change arrived_at to a float (can't yet because it causes Event + # equality checks to fail - events include arrived_at within their hash + # whereas ControlMessages don't). + + self.arrived_at = int(arrived_at if arrived_at else time.time()) self._parsed_content = parsed_content self._raw_content = raw_content - self._str = None + self._str = None # type: Optional[str] self._hash = stem.util._hash_attr(self, '_raw_content') - def is_ok(self): + def is_ok(self) -> bool: """ Checks if any of our lines have a 250 response. 
@@ -195,7 +237,12 @@ def is_ok(self): return False - def content(self, get_bytes = False): + # TODO: drop this alias when we provide better type support + + def _content_bytes(self) -> List[Tuple[str, str, bytes]]: + return self.content(get_bytes = True) # type: ignore + + def content(self, get_bytes: bool = False) -> List[Tuple[str, str, str]]: """ Provides the parsed message content. These are entries of the form... @@ -232,9 +279,9 @@ def content(self, get_bytes = False): if not get_bytes: return [(code, div, stem.util.str_tools._to_unicode(content)) for (code, div, content) in self._parsed_content] else: - return list(self._parsed_content) + return list(self._parsed_content) # type: ignore - def raw_content(self, get_bytes = False): + def raw_content(self, get_bytes: bool = False) -> Union[str, bytes]: """ Provides the unparsed content read from the control socket. @@ -251,7 +298,10 @@ def raw_content(self, get_bytes = False): else: return self._raw_content - def __str__(self): + def _parse_message(self) -> None: + raise NotImplementedError('Implemented by subclasses') + + def __str__(self) -> str: """ Content of the message, stripped of status code and divider protocol formatting. @@ -262,7 +312,7 @@ def __str__(self): return self._str - def __iter__(self): + def __iter__(self) -> Iterator['stem.response.ControlLine']: """ Provides :class:`~stem.response.ControlLine` instances for the content of the message. This is stripped of status codes and dividers, for instance... 
@@ -286,18 +336,16 @@ def __iter__(self): """ for _, _, content in self._parsed_content: - content = stem.util.str_tools._to_unicode(content) + yield ControlLine(stem.util.str_tools._to_unicode(content)) - yield ControlLine(content) - - def __len__(self): + def __len__(self) -> int: """ :returns: number of ControlLines """ return len(self._parsed_content) - def __getitem__(self, index): + def __getitem__(self, index: int) -> 'stem.response.ControlLine': """ :returns: :class:`~stem.response.ControlLine` at the index """ @@ -307,13 +355,13 @@ def __getitem__(self, index): return ControlLine(content) - def __hash__(self): + def __hash__(self) -> int: return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ControlMessage) else False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -327,14 +375,14 @@ class ControlLine(str): immutable). All methods are thread safe. """ - def __new__(self, value): - return str.__new__(self, value) + def __new__(self, value: str) -> 'stem.response.ControlLine': + return str.__new__(self, value) # type: ignore - def __init__(self, value): + def __init__(self, value: str) -> None: self._remainder = value self._remainder_lock = threading.RLock() - def remainder(self): + def remainder(self) -> str: """ Provides our unparsed content. This is an empty string after we've popped all entries. @@ -344,7 +392,7 @@ def remainder(self): return self._remainder - def is_empty(self): + def is_empty(self) -> bool: """ Checks if we have further content to pop or not. @@ -353,7 +401,7 @@ def is_empty(self): return self._remainder == '' - def is_next_quoted(self, escaped = False): + def is_next_quoted(self, escaped: bool = False) -> bool: """ Checks if our next entry is a quoted value or not. 
@@ -365,7 +413,7 @@ def is_next_quoted(self, escaped = False): start_quote, end_quote = _get_quote_indices(self._remainder, escaped) return start_quote == 0 and end_quote != -1 - def is_next_mapping(self, key = None, quoted = False, escaped = False): + def is_next_mapping(self, key: Optional[str] = None, quoted: bool = False, escaped: bool = False) -> bool: """ Checks if our next entry is a KEY=VALUE mapping or not. @@ -393,7 +441,7 @@ def is_next_mapping(self, key = None, quoted = False, escaped = False): else: return False # doesn't start with a key - def peek_key(self): + def peek_key(self) -> str: """ Provides the key of the next entry, providing **None** if it isn't a key/value mapping. @@ -409,7 +457,7 @@ def peek_key(self): else: return None - def pop(self, quoted = False, escaped = False): + def pop(self, quoted: bool = False, escaped: bool = False) -> str: """ Parses the next space separated entry, removing it and the space from our remaining content. Examples... @@ -441,9 +489,14 @@ def pop(self, quoted = False, escaped = False): with self._remainder_lock: next_entry, remainder = _parse_entry(self._remainder, quoted, escaped, False) self._remainder = remainder - return next_entry + return next_entry # type: ignore + + # TODO: drop this alias when we provide better type support + + def _pop_mapping_bytes(self, quoted: bool = False, escaped: bool = False) -> Tuple[str, bytes]: + return self.pop_mapping(quoted, escaped, get_bytes = True) # type: ignore - def pop_mapping(self, quoted = False, escaped = False, get_bytes = False): + def pop_mapping(self, quoted: bool = False, escaped: bool = False, get_bytes: bool = False) -> Tuple[str, str]: """ Parses the next space separated entry as a KEY=VALUE mapping, removing it and the space from our remaining content. 
@@ -477,16 +530,17 @@ def pop_mapping(self, quoted = False, escaped = False, get_bytes = False): next_entry, remainder = _parse_entry(remainder, quoted, escaped, get_bytes) self._remainder = remainder - return (key, next_entry) + return (key, next_entry) # type: ignore -def _parse_entry(line, quoted, escaped, get_bytes): +def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tuple[Union[str, bytes], str]: """ Parses the next entry from the given space separated content. :param str line: content to be parsed :param bool quoted: parses the next entry as a quoted value, removing the quotes :param bool escaped: unescapes the string + :param bool get_bytes: provides **bytes** for the entry rather than a **str** :returns: **tuple** of the form (entry, remainder) @@ -529,18 +583,18 @@ def _parse_entry(line, quoted, escaped, get_bytes): # # https://stackoverflow.com/questions/14820429/how-do-i-decodestring-escape-in-python3 - next_entry = codecs.escape_decode(next_entry)[0] + next_entry = codecs.escape_decode(next_entry)[0] # type: ignore if not get_bytes: next_entry = stem.util.str_tools._to_unicode(next_entry) # normalize back to str if get_bytes: - next_entry = stem.util.str_tools._to_bytes(next_entry) - - return (next_entry, remainder.lstrip()) + return (stem.util.str_tools._to_bytes(next_entry), remainder.lstrip()) + else: + return (next_entry, remainder.lstrip()) -def _get_quote_indices(line, escaped): +def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]: """ Provides the indices of the next two quotes in the given content. 
@@ -563,7 +617,7 @@ def _get_quote_indices(line, escaped): indices.append(quote_index) - return tuple(indices) + return tuple(indices) # type: ignore class SingleLineResponse(ControlMessage): @@ -576,7 +630,7 @@ class SingleLineResponse(ControlMessage): :var str message: content of the line """ - def is_ok(self, strict = False): + def is_ok(self, strict: bool = False) -> bool: """ Checks if the response code is "250". If strict is **True** then this checks if the response is "250 OK" @@ -593,7 +647,7 @@ def is_ok(self, strict = False): return self.content()[0][0] == '250' - def _parse_message(self): + def _parse_message(self) -> None: content = self.content() if len(content) > 1: @@ -601,4 +655,7 @@ def _parse_message(self): elif len(content) == 0: raise stem.ProtocolError('Received empty response') else: - self.code, _, self.message = content[0] + code, _, msg = content[0] + + self.code = stem.util.str_tools._to_unicode(code) + self.message = stem.util.str_tools._to_unicode(msg) diff --git a/stem/response/add_onion.py b/stem/response/add_onion.py index 64d582826..3f52f9f21 100644 --- a/stem/response/add_onion.py +++ b/stem/response/add_onion.py @@ -15,7 +15,7 @@ class AddOnionResponse(stem.response.ControlMessage): :var dict client_auth: newly generated client credentials the service accepts """ - def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-ServiceID=gfzprpioee3hoppz # 250-PrivateKey=RSA1024:MIICXgIBAAKBgQDZvYVxv... 
diff --git a/stem/response/authchallenge.py b/stem/response/authchallenge.py index d9cc54918..80a1c0f51 100644 --- a/stem/response/authchallenge.py +++ b/stem/response/authchallenge.py @@ -17,7 +17,7 @@ class AuthChallengeResponse(stem.response.ControlMessage): :var str server_nonce: server nonce provided by tor """ - def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250 AUTHCHALLENGE SERVERHASH=680A73C9836C4F557314EA1C4EDE54C285DB9DC89C83627401AEF9D7D27A95D5 SERVERNONCE=F8EA4B1F2C8B40EF1AF68860171605B910E3BBCABADF6FC3DB1FA064F4690E85 diff --git a/stem/response/events.py b/stem/response/events.py index fdd17a257..65419fe6e 100644 --- a/stem/response/events.py +++ b/stem/response/events.py @@ -1,6 +1,9 @@ # Copyright 2012-2020, Damian Johnson and The Tor Project # See LICENSE for licensing information +# +# mypy: ignore-errors +import datetime import io import re @@ -12,6 +15,7 @@ import stem.version from stem.util import connection, log, str_tools, tor_tools +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union # Matches keyword=value arguments. This can't be a simple "(.*)=(.*)" pattern # because some positional arguments, like circuit paths, can have an equal @@ -33,35 +37,39 @@ class Event(stem.response.ControlMessage): :var dict keyword_args: key/value arguments of the event """ - _POSITIONAL_ARGS = () # attribute names for recognized positional arguments - _KEYWORD_ARGS = {} # map of 'keyword => attribute' for recognized attributes - _QUOTED = () # positional arguments that are quoted - _OPTIONALLY_QUOTED = () # positional arguments that may or may not be quoted + # TODO: Replace metaprogramming with concrete implementations (to simplify type information) + # TODO: _QUOTED looks to be unused + + _POSITIONAL_ARGS = () # type: Tuple[str, ...] 
# attribute names for recognized positional arguments + _KEYWORD_ARGS = {} # type: Dict[str, str] # map of 'keyword => attribute' for recognized attributes + _QUOTED = () # type: Tuple[str, ...] # positional arguments that are quoted + _OPTIONALLY_QUOTED = () # type: Tuple[str, ...] # positional arguments that may or may not be quoted _SKIP_PARSING = False # skip parsing contents into our positional_args and keyword_args _VERSION_ADDED = stem.version.Version('0.1.1.1-alpha') # minimum version with control-spec V1 event support - def _parse_message(self): + def _parse_message(self) -> None: if not str(self).strip(): raise stem.ProtocolError('Received a blank tor event. Events must at the very least have a type.') self.type = str(self).split()[0] - self.positional_args = [] - self.keyword_args = {} + self.positional_args = [] # type: List[str] + self.keyword_args = {} # type: Dict[str, str] # if we're a recognized event type then translate ourselves into that subclass if self.type in EVENT_TYPE_TO_CLASS: self.__class__ = EVENT_TYPE_TO_CLASS[self.type] + self.__init__() # type: ignore if not self._SKIP_PARSING: self._parse_standard_attr() self._parse() - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'arrived_at', parent = stem.response.ControlMessage, cache = True) - def _parse_standard_attr(self): + def _parse_standard_attr(self) -> None: """ Most events are of the form... 650 *( positional_args ) *( key "=" value ) @@ -122,7 +130,7 @@ def _parse_standard_attr(self): for controller_attr_name, attr_name in self._KEYWORD_ARGS.items(): setattr(self, attr_name, self.keyword_args.get(controller_attr_name)) - def _iso_timestamp(self, timestamp): + def _iso_timestamp(self, timestamp: str) -> datetime.datetime: """ Parses an iso timestamp (ISOTime2Frac in the control-spec). 
@@ -142,10 +150,10 @@ def _iso_timestamp(self, timestamp): raise stem.ProtocolError('Unable to parse timestamp (%s): %s' % (exc, self)) # method overwritten by our subclasses for special handling that they do - def _parse(self): + def _parse(self) -> None: pass - def _log_if_unrecognized(self, attr, attr_enum): + def _log_if_unrecognized(self, attr: str, attr_enum: Union[stem.util.enum.Enum, Sequence[stem.util.enum.Enum]]) -> None: """ Checks if an attribute exists in a given enumeration, logging a message if it isn't. Attributes can either be for a string or collection of strings @@ -194,9 +202,17 @@ class AddrMapEvent(Event): 'EXPIRES': 'utc_expiry', 'CACHED': 'cached', } - _OPTIONALLY_QUOTED = ('expiry') + _OPTIONALLY_QUOTED = ('expiry',) + + def __init__(self): + self.hostname = None # type: Optional[str] + self.destination = None # type: Optional[str] + self.expiry = None # type: Optional[datetime.datetime] + self.error = None # type: Optional[str] + self.utc_expiry = None # type: Optional[datetime.datetime] + self.cached = None # type: Optional[bool] - def _parse(self): + def _parse(self) -> None: if self.destination == '': self.destination = None @@ -234,7 +250,11 @@ class BandwidthEvent(Event): _POSITIONAL_ARGS = ('read', 'written') - def _parse(self): + def __init__(self): + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + + def _parse(self) -> None: if not self.read: raise stem.ProtocolError('BW event is missing its read value') elif not self.written: @@ -277,7 +297,18 @@ class BuildTimeoutSetEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.2.7-alpha') - def _parse(self): + def __init__(self): + self.set_type = None # type: Optional[stem.TimeoutSetType] + self.total_times = None # type: Optional[int] + self.timeout = None # type: Optional[int] + self.xm = None # type: Optional[int] + self.alpha = None # type: Optional[float] + self.quantile = None # type: Optional[float] + self.timeout_rate = None # type: 
Optional[float] + self.close_timeout = None # type: Optional[int] + self.close_rate = None # type: Optional[float] + + def _parse(self) -> None: # convert our integer and float parameters for param in ('total_times', 'timeout', 'xm', 'close_timeout'): @@ -346,7 +377,21 @@ class CircuitEvent(Event): 'SOCKS_PASSWORD': 'socks_password', } - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.status = None # type: Optional[stem.CircStatus] + self.path = None # type: Optional[Tuple[Tuple[str, str], ...]] + self.build_flags = None # type: Optional[Tuple[stem.CircBuildFlag, ...]] + self.purpose = None # type: Optional[stem.CircPurpose] + self.hs_state = None # type: Optional[stem.HiddenServiceState] + self.rend_query = None # type: Optional[str] + self.created = None # type: Optional[datetime.datetime] + self.reason = None # type: Optional[stem.CircClosureReason] + self.remote_reason = None # type: Optional[stem.CircClosureReason] + self.socks_username = None # type: Optional[str] + self.socks_password = None # type: Optional[str] + + def _parse(self) -> None: self.path = tuple(stem.control._parse_circ_path(self.path)) self.created = self._iso_timestamp(self.created) @@ -363,7 +408,7 @@ def _parse(self): self._log_if_unrecognized('reason', stem.CircClosureReason) self._log_if_unrecognized('remote_reason', stem.CircClosureReason) - def _compare(self, other, method): + def _compare(self, other: Any, method: Any) -> bool: # sorting circuit events by their identifier if not isinstance(other, CircuitEvent): @@ -374,10 +419,10 @@ def _compare(self, other, method): return method(my_id, their_id) if my_id != their_id else method(hash(self), hash(other)) - def __gt__(self, other): + def __gt__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s > o) - def __ge__(self, other): + def __ge__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s >= o) @@ -414,7 +459,19 @@ class CircMinorEvent(Event): } _VERSION_ADDED 
= stem.version.Version('0.2.3.11-alpha') - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.event = None # type: Optional[stem.CircEvent] + self.path = None # type: Optional[Tuple[Tuple[str, str], ...]] + self.build_flags = None # type: Optional[Tuple[stem.CircBuildFlag, ...]] + self.purpose = None # type: Optional[stem.CircPurpose] + self.hs_state = None # type: Optional[stem.HiddenServiceState] + self.rend_query = None # type: Optional[str] + self.created = None # type: Optional[datetime.datetime] + self.old_purpose = None # type: Optional[stem.CircPurpose] + self.old_hs_state = None # type: Optional[stem.HiddenServiceState] + + def _parse(self) -> None: self.path = tuple(stem.control._parse_circ_path(self.path)) self.created = self._iso_timestamp(self.created) @@ -450,7 +507,12 @@ class ClientsSeenEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.1.10-alpha') - def _parse(self): + def __init__(self): + self.start_time = None # type: Optional[datetime.datetime] + self.locales = None # type: Optional[Dict[str, int]] + self.ip_versions = None # type: Optional[Dict[str, int]] + + def _parse(self) -> None: if self.start_time is not None: self.start_time = stem.util.str_tools._parse_timestamp(self.start_time) @@ -509,7 +571,11 @@ class ConfChangedEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.2.3.3-alpha') - def _parse(self): + def __init__(self): + self.changed = {} # type: Dict[str, List[str]] + self.unset = [] # type: List[str] + + def _parse(self) -> None: self.changed = {} self.unset = [] @@ -540,6 +606,9 @@ class DescChangedEvent(Event): _VERSION_ADDED = stem.version.Version('0.1.2.2-alpha') + def __init__(self): + pass + class GuardEvent(Event): """ @@ -563,10 +632,14 @@ class GuardEvent(Event): _VERSION_ADDED = stem.version.Version('0.1.2.5-alpha') _POSITIONAL_ARGS = ('guard_type', 'endpoint', 'status') - def _parse(self): - self.endpoint_fingerprint = None - self.endpoint_nickname = 
None + def __init__(self): + self.guard_type = None # type: Optional[stem.GuardType] + self.endpoint = None # type: Optional[str] + self.endpoint_fingerprint = None # type: Optional[str] + self.endpoint_nickname = None # type: Optional[str] + self.status = None # type: Optional[stem.GuardStatus] + def _parse(self) -> None: try: self.endpoint_fingerprint, self.endpoint_nickname = \ stem.control._parse_circ_entry(self.endpoint) @@ -610,10 +683,19 @@ class HSDescEvent(Event): _POSITIONAL_ARGS = ('action', 'address', 'authentication', 'directory', 'descriptor_id') _KEYWORD_ARGS = {'REASON': 'reason', 'REPLICA': 'replica', 'HSDIR_INDEX': 'index'} - def _parse(self): - self.directory_fingerprint = None - self.directory_nickname = None - + def __init__(self): + self.action = None # type: Optional[stem.HSDescAction] + self.address = None # type: Optional[str] + self.authentication = None # type: Optional[stem.HSAuth] + self.directory = None # type: Optional[str] + self.directory_fingerprint = None # type: Optional[str] + self.directory_nickname = None # type: Optional[str] + self.descriptor_id = None # type: Optional[str] + self.reason = None # type: Optional[stem.HSDescReason] + self.replica = None # type: Optional[int] + self.index = None # type: Optional[str] + + def _parse(self) -> None: if self.directory != 'UNKNOWN': try: self.directory_fingerprint, self.directory_nickname = \ @@ -650,13 +732,18 @@ class HSDescContentEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.7.1-alpha') _POSITIONAL_ARGS = ('address', 'descriptor_id', 'directory') - def _parse(self): + def __init__(self): + self.address = None # type: Optional[str] + self.descriptor_id = None # type: Optional[str] + self.directory = None # type: Optional[str] + self.directory_fingerprint = None # type: Optional[str] + self.directory_nickname = None # type: Optional[str] + self.descriptor = None # type: Optional[stem.descriptor.hidden_service.HiddenServiceDescriptorV2] + + def _parse(self) -> None: if 
self.address == 'UNKNOWN': self.address = None - self.directory_fingerprint = None - self.directory_nickname = None - try: self.directory_fingerprint, self.directory_nickname = \ stem.control._parse_circ_entry(self.directory) @@ -686,7 +773,11 @@ class LogEvent(Event): _SKIP_PARSING = True - def _parse(self): + def __init__(self): + self.runlevel = None # type: Optional[stem.Runlevel] + self.message = None # type: Optional[str] + + def _parse(self) -> None: self.runlevel = self.type self._log_if_unrecognized('runlevel', stem.Runlevel) @@ -709,7 +800,10 @@ class NetworkStatusEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.1.2.3-alpha') - def _parse(self): + def __init__(self): + self.descriptors = None # type: Optional[List[stem.descriptor.router_status_entry.RouterStatusEntryV3]] + + def _parse(self) -> None: content = str(self).lstrip('NS\n').rstrip('\nOK') self.descriptors = list(stem.descriptor.router_status_entry._parse_file( @@ -734,6 +828,9 @@ class NetworkLivenessEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.7.2-alpha') _POSITIONAL_ARGS = ('status',) + def __init__(self): + self.status = None # type: Optional[str] + class NewConsensusEvent(Event): """ @@ -753,11 +850,14 @@ class NewConsensusEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.2.1.13-alpha') - def _parse(self): + def __init__(self): + self.consensus_content = None # type: Optional[str] + self._parsed = None # type: List[stem.descriptor.router_status_entry.RouterStatusEntryV3] + + def _parse(self) -> None: self.consensus_content = str(self).lstrip('NEWCONSENSUS\n').rstrip('\nOK') - self._parsed = None - def entries(self): + def entries(self) -> List[stem.descriptor.router_status_entry.RouterStatusEntryV3]: """ Relay router status entries residing within this consensus. 
@@ -773,7 +873,7 @@ def entries(self): entry_class = stem.descriptor.router_status_entry.RouterStatusEntryV3, )) - return self._parsed + return list(self._parsed) class NewDescEvent(Event): @@ -791,7 +891,10 @@ class NewDescEvent(Event): new descriptors """ - def _parse(self): + def __init__(self): + self.relays = () # type: Tuple[Tuple[str, str], ...] + + def _parse(self) -> None: self.relays = tuple([stem.control._parse_circ_entry(entry) for entry in str(self).split()[1:]]) @@ -832,12 +935,18 @@ class ORConnEvent(Event): 'ID': 'id', } - def _parse(self): - self.endpoint_fingerprint = None - self.endpoint_nickname = None - self.endpoint_address = None - self.endpoint_port = None - + def __init__(self): + self.id = None # type: Optional[str] + self.endpoint = None # type: Optional[str] + self.endpoint_fingerprint = None # type: Optional[str] + self.endpoint_nickname = None # type: Optional[str] + self.endpoint_address = None # type: Optional[str] + self.endpoint_port = None # type: Optional[int] + self.status = None # type: Optional[stem.ORStatus] + self.reason = None # type: Optional[stem.ORClosureReason] + self.circ_count = None # type: Optional[int] + + def _parse(self) -> None: try: self.endpoint_fingerprint, self.endpoint_nickname = \ stem.control._parse_circ_entry(self.endpoint) @@ -886,7 +995,10 @@ class SignalEvent(Event): _POSITIONAL_ARGS = ('signal',) _VERSION_ADDED = stem.version.Version('0.2.3.1-alpha') - def _parse(self): + def __init__(self): + self.signal = None # type: Optional[stem.Signal] + + def _parse(self) -> None: # log if we recieved an unrecognized signal expected_signals = ( stem.Signal.RELOAD, @@ -918,7 +1030,13 @@ class StatusEvent(Event): _POSITIONAL_ARGS = ('runlevel', 'action') _VERSION_ADDED = stem.version.Version('0.1.2.3-alpha') - def _parse(self): + def __init__(self): + self.status_type = None # type: Optional[stem.StatusType] + self.runlevel = None # type: Optional[stem.Runlevel] + self.action = None # type: Optional[str] + 
self.arguments = None # type: Optional[Dict[str, str]] + + def _parse(self) -> None: if self.type == 'STATUS_GENERAL': self.status_type = stem.StatusType.GENERAL elif self.type == 'STATUS_CLIENT': @@ -970,7 +1088,22 @@ class StreamEvent(Event): 'PURPOSE': 'purpose', } - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.status = None # type: Optional[stem.StreamStatus] + self.circ_id = None # type: Optional[str] + self.target = None # type: Optional[str] + self.target_address = None # type: Optional[str] + self.target_port = None # type: Optional[int] + self.reason = None # type: Optional[stem.StreamClosureReason] + self.remote_reason = None # type: Optional[stem.StreamClosureReason] + self.source = None # type: Optional[stem.StreamSource] + self.source_addr = None # type: Optional[str] + self.source_address = None # type: Optional[str] + self.source_port = None # type: Optional[str] + self.purpose = None # type: Optional[stem.StreamPurpose] + + def _parse(self) -> None: if self.target is None: raise stem.ProtocolError("STREAM event didn't have a target: %s" % self) else: @@ -1029,7 +1162,13 @@ class StreamBwEvent(Event): _POSITIONAL_ARGS = ('id', 'written', 'read', 'time') _VERSION_ADDED = stem.version.Version('0.1.2.8-beta') - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.written = None # type: Optional[int] + self.read = None # type: Optional[int] + self.time = None # type: Optional[datetime.datetime] + + def _parse(self) -> None: if not tor_tools.is_valid_stream_id(self.id): raise stem.ProtocolError("Stream IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) elif not self.written: @@ -1062,7 +1201,13 @@ class TransportLaunchedEvent(Event): _POSITIONAL_ARGS = ('type', 'name', 'address', 'port') _VERSION_ADDED = stem.version.Version('0.2.5.0-alpha') - def _parse(self): + def __init__(self): + self.type = None # type: Optional[str] + self.name = None # type: 
Optional[str] + self.address = None # type: Optional[str] + self.port = None # type: Optional[int] + + def _parse(self) -> None: if self.type not in ('server', 'client'): raise stem.ProtocolError("Transport type should either be 'server' or 'client': %s" % self) @@ -1104,7 +1249,13 @@ class attribute with the same name. _VERSION_ADDED = stem.version.Version('0.2.5.2-alpha') - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.conn_type = None # type: Optional[stem.ConnectionType] + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + + def _parse(self) -> None: if not self.id: raise stem.ProtocolError('CONN_BW event is missing its id') elif not self.conn_type: @@ -1163,7 +1314,17 @@ class CircuitBandwidthEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.5.2-alpha') - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + self.delivered_read = None # type: Optional[int] + self.delivered_written = None # type: Optional[int] + self.overhead_read = None # type: Optional[int] + self.overhead_written = None # type: Optional[int] + self.time = None # type: Optional[datetime.datetime] + + def _parse(self) -> None: if not self.id: raise stem.ProtocolError('CIRC_BW event is missing its id') elif not self.read: @@ -1233,7 +1394,20 @@ class CellStatsEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.5.2-alpha') - def _parse(self): + def __init__(self): + self.id = None # type: Optional[str] + self.inbound_queue = None # type: Optional[str] + self.inbound_connection = None # type: Optional[str] + self.inbound_added = None # type: Optional[Dict[str, int]] + self.inbound_removed = None # type: Optional[Dict[str, int]] + self.inbound_time = None # type: Optional[Dict[str, int]] + self.outbound_queue = None # type: Optional[str] + self.outbound_connection = None # type: Optional[str] + 
self.outbound_added = None # type: Optional[Dict[str, int]] + self.outbound_removed = None # type: Optional[Dict[str, int]] + self.outbound_time = None # type: Optional[Dict[str, int]] + + def _parse(self) -> None: if self.id and not tor_tools.is_valid_circuit_id(self.id): raise stem.ProtocolError("Circuit IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) elif self.inbound_queue and not tor_tools.is_valid_circuit_id(self.inbound_queue): @@ -1279,7 +1453,14 @@ class TokenBucketEmptyEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.5.2-alpha') - def _parse(self): + def __init__(self): + self.bucket = None # type: Optional[stem.TokenBucket] + self.id = None # type: Optional[str] + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + self.last_refill = None # type: Optional[int] + + def _parse(self) -> None: if self.id and not tor_tools.is_valid_connection_id(self.id): raise stem.ProtocolError("Connection IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) elif not self.read.isdigit(): @@ -1296,7 +1477,7 @@ def _parse(self): self._log_if_unrecognized('bucket', stem.TokenBucket) -def _parse_cell_type_mapping(mapping): +def _parse_cell_type_mapping(mapping: str) -> Dict[str, int]: """ Parses a mapping of the form... 
diff --git a/stem/response/getconf.py b/stem/response/getconf.py index 6de49b1f6..6c65c4ec4 100644 --- a/stem/response/getconf.py +++ b/stem/response/getconf.py @@ -4,6 +4,8 @@ import stem.response import stem.socket +from typing import Dict, List + class GetConfResponse(stem.response.ControlMessage): """ @@ -16,14 +18,14 @@ class GetConfResponse(stem.response.ControlMessage): values (**list** of **str**) """ - def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-CookieAuthentication=0 # 250-ControlPort=9100 # 250-DataDirectory=/home/neena/.tor # 250 DirPort - self.entries = {} + self.entries = {} # type: Dict[str, List[str]] remaining_lines = list(self) if self.content() == [('250', ' ', 'OK')]: diff --git a/stem/response/getinfo.py b/stem/response/getinfo.py index 27442ffd9..9d9da21bc 100644 --- a/stem/response/getinfo.py +++ b/stem/response/getinfo.py @@ -4,6 +4,8 @@ import stem.response import stem.socket +from typing import Dict, Set + class GetInfoResponse(stem.response.ControlMessage): """ @@ -12,7 +14,7 @@ class GetInfoResponse(stem.response.ControlMessage): :var dict entries: mapping between the queried options and their bytes values """ - def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-version=0.2.3.11-alpha-dev (git-ef0bc7f8f26a917c) # 250+config-text= @@ -25,8 +27,8 @@ def _parse_message(self): # . 
# 250 OK - self.entries = {} - remaining_lines = [content for (code, div, content) in self.content(get_bytes = True)] + self.entries = {} # type: Dict[str, bytes] + remaining_lines = [content for (code, div, content) in self._content_bytes()] if not self.is_ok() or not remaining_lines.pop() == b'OK': unrecognized_keywords = [] @@ -49,11 +51,11 @@ def _parse_message(self): while remaining_lines: try: - key, value = remaining_lines.pop(0).split(b'=', 1) + key_bytes, value = remaining_lines.pop(0).split(b'=', 1) except ValueError: raise stem.ProtocolError('GETINFO replies should only contain parameter=value mappings:\n%s' % self) - key = stem.util.str_tools._to_unicode(key) + key = stem.util.str_tools._to_unicode(key_bytes) # if the value is a multiline value then it *must* be of the form # '=\n' @@ -66,7 +68,7 @@ def _parse_message(self): self.entries[key] = value - def _assert_matches(self, params): + def _assert_matches(self, params: Set[str]) -> None: """ Checks if we match a given set of parameters, and raise a ProtocolError if not. 
diff --git a/stem/response/mapaddress.py b/stem/response/mapaddress.py index 73ed84f17..92ce16d21 100644 --- a/stem/response/mapaddress.py +++ b/stem/response/mapaddress.py @@ -17,7 +17,7 @@ class MapAddressResponse(stem.response.ControlMessage): * :class:`stem.InvalidRequest` if the addresses provided were invalid """ - def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-127.192.10.10=torproject.org # 250 1.2.3.4=tor.freehaven.net diff --git a/stem/response/protocolinfo.py b/stem/response/protocolinfo.py index 459fef5b4..c1387fab7 100644 --- a/stem/response/protocolinfo.py +++ b/stem/response/protocolinfo.py @@ -8,8 +8,8 @@ import stem.version import stem.util.str_tools -from stem.connection import AuthMethod from stem.util import log +from typing import Tuple class ProtocolInfoResponse(stem.response.ControlMessage): @@ -26,17 +26,19 @@ class ProtocolInfoResponse(stem.response.ControlMessage): :var str cookie_path: path of tor's authentication cookie """ - def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-PROTOCOLINFO 1 # 250-AUTH METHODS=COOKIE COOKIEFILE="/home/atagar/.tor/control_auth_cookie" # 250-VERSION Tor="0.2.1.30" # 250 OK + from stem.connection import AuthMethod + self.protocol_version = None self.tor_version = None - self.auth_methods = () - self.unknown_auth_methods = () + self.auth_methods = () # type: Tuple[stem.connection.AuthMethod, ...] + self.unknown_auth_methods = () # type: Tuple[str, ...] 
self.cookie_path = None auth_methods, unknown_auth_methods = [], [] @@ -106,7 +108,7 @@ def _parse_message(self): # parse optional COOKIEFILE mapping (quoted and can have escapes) if line.is_next_mapping('COOKIEFILE', True, True): - self.cookie_path = line.pop_mapping(True, True, get_bytes = True)[1].decode(sys.getfilesystemencoding()) + self.cookie_path = line._pop_mapping_bytes(True, True)[1].decode(sys.getfilesystemencoding()) self.cookie_path = stem.util.str_tools._to_unicode(self.cookie_path) # normalize back to str elif line_type == 'VERSION': # Line format: diff --git a/stem/socket.py b/stem/socket.py index db110973f..b5da4b78f 100644 --- a/stem/socket.py +++ b/stem/socket.py @@ -62,8 +62,7 @@ |- is_localhost - returns if the socket is for the local system or not |- connection_time - timestamp when socket last connected or disconnected |- connect - connects a new socket - |- close - shuts down the socket - +- __enter__ / __exit__ - manages socket connection + +- close - shuts down the socket send_message - Writes a message to a control socket. recv_message - Reads a ControlMessage from a control socket. @@ -80,6 +79,8 @@ import stem.util.str_tools from stem.util import log +from types import TracebackType +from typing import BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]') ERROR_MSG = 'Error while receiving a control message (%s): %s' @@ -94,8 +95,9 @@ class BaseSocket(object): Thread safe socket, providing common socket functionality. 
""" - def __init__(self): - self._socket, self._socket_file = None, None + def __init__(self) -> None: + self._socket = None # type: Optional[Union[socket.socket, ssl.SSLSocket]] + self._socket_file = None # type: Optional[BinaryIO] self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected @@ -106,7 +108,7 @@ def __init__(self): self._send_lock = threading.RLock() self._recv_lock = threading.RLock() - def is_alive(self): + def is_alive(self) -> bool: """ Checks if the socket is known to be closed. We won't be aware if it is until we either use it or have explicitily shut it down. @@ -125,7 +127,7 @@ def is_alive(self): return self._is_alive - def is_localhost(self): + def is_localhost(self) -> bool: """ Returns if the connection is for the local system or not. @@ -135,7 +137,7 @@ def is_localhost(self): return False - def connection_time(self): + def connection_time(self) -> float: """ Provides the unix timestamp for when our socket was either connected or disconnected. That is to say, the time we connected if we're currently @@ -149,7 +151,7 @@ def connection_time(self): return self._connection_time - def connect(self): + def connect(self) -> None: """ Connects to a new socket, closing our previous one if we're already attached. @@ -181,7 +183,7 @@ def connect(self): except stem.SocketError: self._connect() # single retry - def close(self): + def close(self) -> None: """ Shuts down the socket. If it's already closed then this is a no-op. """ @@ -217,7 +219,7 @@ def close(self): if is_change: self._close() - def _send(self, message, handler): + def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None: """ Send message in a thread safe manner. Handler is expected to be of the form... @@ -241,6 +243,14 @@ def _send(self, message, handler): raise + @overload + def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes: + ... 
+ + @overload + def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage: + ... + def _recv(self, handler): """ Receives a message in a thread safe manner. Handler is expected to be of the form... @@ -283,7 +293,7 @@ def _recv(self, handler): raise - def _get_send_lock(self): + def _get_send_lock(self) -> threading.RLock: """ The send lock is useful to classes that interact with us at a deep level because it's used to lock :func:`stem.socket.ControlSocket.connect` / @@ -296,27 +306,27 @@ def _get_send_lock(self): return self._send_lock - def __enter__(self): + def __enter__(self) -> 'stem.socket.BaseSocket': return self - def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): self.close() - def _connect(self): + def _connect(self) -> None: """ Connection callback that can be overwritten by subclasses and wrappers. """ pass - def _close(self): + def _close(self) -> None: """ Disconnection callback that can be overwritten by subclasses and wrappers. """ pass - def _make_socket(self): + def _make_socket(self) -> Union[socket.socket, ssl.SSLSocket]: """ Constructs and connects new socket. This is implemented by subclasses. @@ -342,7 +352,7 @@ class RelaySocket(BaseSocket): :var int port: ORPort our socket connects to """ - def __init__(self, address = '127.0.0.1', port = 9050, connect = True): + def __init__(self, address: str = '127.0.0.1', port: int = 9050, connect: bool = True) -> None: """ RelaySocket constructor. @@ -361,7 +371,7 @@ def __init__(self, address = '127.0.0.1', port = 9050, connect = True): if connect: self.connect() - def send(self, message): + def send(self, message: Union[str, bytes]) -> None: """ Sends a message to the relay's ORPort. 
@@ -374,7 +384,7 @@ def send(self, message): self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg)) - def recv(self, timeout = None): + def recv(self, timeout: Optional[float] = None) -> bytes: """ Receives a message from the relay. @@ -388,26 +398,26 @@ def recv(self, timeout = None): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """ - def wrapped_recv(s, sf): + def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes: if timeout is None: - return s.recv() + return s.recv(1024) else: - s.setblocking(0) + s.setblocking(False) s.settimeout(timeout) try: - return s.recv() + return s.recv(1024) except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError): return None finally: - s.setblocking(1) + s.setblocking(True) return self._recv(wrapped_recv) - def is_localhost(self): + def is_localhost(self) -> bool: return self.address == '127.0.0.1' - def _make_socket(self): + def _make_socket(self) -> ssl.SSLSocket: try: relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) relay_socket.connect((self.address, self.port)) @@ -426,10 +436,10 @@ class ControlSocket(BaseSocket): which are expected to implement the **_make_socket()** method. """ - def __init__(self): + def __init__(self) -> None: super(ControlSocket, self).__init__() - def send(self, message): + def send(self, message: Union[bytes, str]) -> None: """ Formats and sends a message to the control socket. For more information see the :func:`~stem.socket.send_message` function. @@ -443,7 +453,7 @@ def send(self, message): self._send(message, lambda s, sf, msg: send_message(sf, msg)) - def recv(self): + def recv(self) -> stem.response.ControlMessage: """ Receives a message from the control socket, blocking until we've received one. For more information see the :func:`~stem.socket.recv_message` function. 
@@ -467,7 +477,7 @@ class ControlPort(ControlSocket): :var int port: ControlPort our socket connects to """ - def __init__(self, address = '127.0.0.1', port = 9051, connect = True): + def __init__(self, address: str = '127.0.0.1', port: int = 9051, connect: bool = True) -> None: """ ControlPort constructor. @@ -486,10 +496,10 @@ def __init__(self, address = '127.0.0.1', port = 9051, connect = True): if connect: self.connect() - def is_localhost(self): + def is_localhost(self) -> bool: return self.address == '127.0.0.1' - def _make_socket(self): + def _make_socket(self) -> socket.socket: try: control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) control_socket.connect((self.address, self.port)) @@ -506,7 +516,7 @@ class ControlSocketFile(ControlSocket): :var str path: filesystem path of the socket we connect to """ - def __init__(self, path = '/var/run/tor/control', connect = True): + def __init__(self, path: str = '/var/run/tor/control', connect: bool = True) -> None: """ ControlSocketFile constructor. @@ -523,10 +533,10 @@ def __init__(self, path = '/var/run/tor/control', connect = True): if connect: self.connect() - def is_localhost(self): + def is_localhost(self) -> bool: return True - def _make_socket(self): + def _make_socket(self) -> socket.socket: try: control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) control_socket.connect(self.path) @@ -535,7 +545,7 @@ def _make_socket(self): raise stem.SocketError(exc) -def send_message(control_file, message, raw = False): +def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = False) -> None: """ Sends a message to the control socket, adding the expected formatting for single verses multi-line messages. 
Neither message type should contain an @@ -567,6 +577,8 @@ def send_message(control_file, message, raw = False): * :class:`stem.SocketClosed` if the socket is known to be shut down """ + message = stem.util.str_tools._to_unicode(message) + if not raw: message = send_formatting(message) @@ -578,7 +590,7 @@ def send_message(control_file, message, raw = False): log.trace('Sent to tor:%s%s' % (msg_div, log_message)) -def _write_to_socket(socket_file, message): +def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None: try: socket_file.write(stem.util.str_tools._to_bytes(message)) socket_file.flush() @@ -601,7 +613,7 @@ def _write_to_socket(socket_file, message): raise stem.SocketClosed('file has been closed') -def recv_message(control_file, arrived_at = None): +def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage: """ Pulls from a control socket until we either have a complete message or encounter a problem. @@ -617,7 +629,9 @@ def recv_message(control_file, arrived_at = None): a complete message """ - parsed_content, raw_content, first_line = None, None, True + parsed_content = [] # type: List[Tuple[str, str, bytes]] + raw_content = bytearray() + first_line = True while True: try: @@ -648,10 +662,10 @@ def recv_message(control_file, arrived_at = None): log.info(ERROR_MSG % ('SocketClosed', 'empty socket content')) raise stem.SocketClosed('Received empty socket content.') elif not MESSAGE_PREFIX.match(line): - log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line))) + log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line.decode('utf-8')))) raise stem.ProtocolError('Badly formatted reply line: beginning is malformed') elif not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line))) + log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % 
log.escape(line.decode('utf-8')))) raise stem.ProtocolError('All lines should end with CRLF') status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content @@ -690,11 +704,11 @@ def recv_message(control_file, arrived_at = None): line = control_file.readline() raw_content += line except socket.error as exc: - log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content))))) + log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8'))))) raise stem.SocketClosed(exc) if not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content)))) + log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content).decode('utf-8')))) raise stem.ProtocolError('All lines should end with CRLF') elif line == b'.\r\n': break # data block termination @@ -721,7 +735,7 @@ def recv_message(control_file, arrived_at = None): raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line))) -def send_formatting(message): +def send_formatting(message: str) -> str: """ Performs the formatting expected from sent control messages. For more information see the :func:`~stem.socket.send_message` function. 
@@ -750,7 +764,7 @@ def send_formatting(message): return message + '\r\n' -def _log_trace(response): +def _log_trace(response: bytes) -> None: if not log.is_tracing(): return diff --git a/stem/util/__init__.py b/stem/util/__init__.py index e4e081746..498234cdc 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -7,6 +7,8 @@ import datetime +from typing import Any, Union + __all__ = [ 'conf', 'connection', @@ -43,7 +45,7 @@ HASH_TYPES = True -def _hash_value(val): +def _hash_value(val: Any) -> int: if not HASH_TYPES: my_hash = 0 else: @@ -64,7 +66,7 @@ def _hash_value(val): return my_hash -def datetime_to_unix(timestamp): +def datetime_to_unix(timestamp: 'datetime.datetime') -> float: """ Converts a utc datetime object to a unix timestamp. @@ -78,13 +80,15 @@ def datetime_to_unix(timestamp): return (timestamp - datetime.datetime(1970, 1, 1)).total_seconds() -def _pubkey_bytes(key): +def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes: # type: ignore """ Normalizes X25509 and ED25519 keys into their public key bytes. """ - if isinstance(key, (bytes, str)): + if isinstance(key, bytes): return key + elif isinstance(key, str): + return key.encode('utf-8') try: from cryptography.hazmat.primitives import serialization @@ -107,7 +111,7 @@ def _pubkey_bytes(key): raise ValueError('Key must be a string or cryptographic public/private key (was %s)' % type(key).__name__) -def _hash_attr(obj, *attributes, **kwargs): +def _hash_attr(obj: Any, *attributes: str, **kwargs: Any): """ Provide a hash value for the given set of attributes. 
diff --git a/stem/util/conf.py b/stem/util/conf.py index a06f1fd79..1fd31fd03 100644 --- a/stem/util/conf.py +++ b/stem/util/conf.py @@ -162,17 +162,20 @@ def config_validator(key, value): import os import threading +import stem.util.enum + from stem.util import log +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union CONFS = {} # mapping of identifier to singleton instances of configs class _SyncListener(object): - def __init__(self, config_dict, interceptor): + def __init__(self, config_dict: Mapping[str, Any], interceptor: Callable[[str, Any], Any]) -> None: self.config_dict = config_dict self.interceptor = interceptor - def update(self, config, key): + def update(self, config: 'stem.util.conf.Config', key: str) -> None: if key in self.config_dict: new_value = config.get(key, self.config_dict[key]) @@ -185,10 +188,10 @@ def update(self, config, key): if interceptor_value: new_value = interceptor_value - self.config_dict[key] = new_value + self.config_dict[key] = new_value # type: ignore -def config_dict(handle, conf_mappings, handler = None): +def config_dict(handle: str, conf_mappings: Dict[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Dict[str, Any]: """ Makes a dictionary that stays synchronized with a configuration. @@ -214,6 +217,8 @@ def config_dict(handle, conf_mappings, handler = None): :param str handle: unique identifier for a config instance :param dict conf_mappings: config key/value mappings used as our defaults :param functor handler: function referred to prior to assigning values + + :returns: mapping of attributes to their current configuration value """ selected_config = get_config(handle) @@ -221,7 +226,7 @@ def config_dict(handle, conf_mappings, handler = None): return conf_mappings -def get_config(handle): +def get_config(handle: str) -> 'stem.util.conf.Config': """ Singleton constructor for configuration file instances. If a configuration already exists for the handle then it's returned. 
Otherwise a fresh instance @@ -236,7 +241,7 @@ def get_config(handle): return CONFS[handle] -def uses_settings(handle, path, lazy_load = True): +def uses_settings(handle: str, path: str, lazy_load: bool = True) -> Callable: """ Provides a function that can be used as a decorator for other functions that require settings to be loaded. Functions with this decorator will be provided @@ -272,13 +277,13 @@ def my_function(config): config.load(path) config._settings_loaded = True - def decorator(func): - def wrapped(*args, **kwargs): + def decorator(func: Callable) -> Callable: + def wrapped(*args: Any, **kwargs: Any) -> Any: if lazy_load and not config._settings_loaded: config.load(path) config._settings_loaded = True - if 'config' in inspect.getargspec(func).args: + if 'config' in inspect.getfullargspec(func).args: return func(*args, config = config, **kwargs) else: return func(*args, **kwargs) @@ -288,7 +293,7 @@ def wrapped(*args, **kwargs): return decorator -def parse_enum(key, value, enumeration): +def parse_enum(key: str, value: str, enumeration: 'stem.util.enum.Enum') -> Any: """ Provides the enumeration value for a given key. This is a case insensitive lookup and raises an exception if the enum key doesn't exist. @@ -305,7 +310,7 @@ def parse_enum(key, value, enumeration): return parse_enum_csv(key, value, enumeration, 1)[0] -def parse_enum_csv(key, value, enumeration, count = None): +def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> List[Any]: """ Parses a given value as being a comma separated listing of enumeration keys, returning the corresponding enumeration values. This is intended to be a @@ -445,21 +450,21 @@ class Config(object): Class can now be used as a dictionary. 
""" - def __init__(self): - self._path = None # location we last loaded from or saved to - self._contents = collections.OrderedDict() # configuration key/value pairs - self._listeners = [] # functors to be notified of config changes + def __init__(self) -> None: + self._path = None # type: Optional[str] # location we last loaded from or saved to + self._contents = collections.OrderedDict() # type: Dict[str, Any] # configuration key/value pairs + self._listeners = [] # type: List[Callable[['stem.util.conf.Config', str], Any]] # functors to be notified of config changes # used for accessing _contents self._contents_lock = threading.RLock() # keys that have been requested (used to provide unused config contents) - self._requested_keys = set() + self._requested_keys = set() # type: Set[str] # flag to support lazy loading in uses_settings() self._settings_loaded = False - def load(self, path = None, commenting = True): + def load(self, path: Optional[str] = None, commenting: bool = True) -> None: """ Reads in the contents of the given path, adding its configuration values to our current contents. If the path is a directory then this loads each @@ -534,7 +539,7 @@ def load(self, path = None, commenting = True): else: self.set(line, '', False) # default to a key => '' mapping - def save(self, path = None): + def save(self, path: Optional[str] = None) -> None: """ Saves configuration contents to disk. If a path is provided then it replaces the configuration location that we track. @@ -564,7 +569,7 @@ def save(self, path = None): output_file.write('%s %s\n' % (entry_key, entry_value)) - def clear(self): + def clear(self) -> None: """ Drops the configuration contents and reverts back to a blank, unloaded state. 
@@ -574,7 +579,7 @@ def clear(self): self._contents.clear() self._requested_keys = set() - def add_listener(self, listener, backfill = True): + def add_listener(self, listener: Callable[['stem.util.conf.Config', str], Any], backfill: bool = True) -> None: """ Registers the function to be notified of configuration updates. Listeners are expected to be functors which accept (config, key). @@ -590,14 +595,14 @@ def add_listener(self, listener, backfill = True): for key in self.keys(): listener(self, key) - def clear_listeners(self): + def clear_listeners(self) -> None: """ Removes all attached listeners. """ self._listeners = [] - def keys(self): + def keys(self) -> List[str]: """ Provides all keys in the currently loaded configuration. @@ -606,7 +611,7 @@ def keys(self): return list(self._contents.keys()) - def unused_keys(self): + def unused_keys(self) -> Set[str]: """ Provides the configuration keys that have never been provided to a caller via :func:`~stem.util.conf.config_dict` or the @@ -618,7 +623,7 @@ def unused_keys(self): return set(self.keys()).difference(self._requested_keys) - def set(self, key, value, overwrite = True): + def set(self, key: str, value: Union[str, Sequence[str]], overwrite: bool = True) -> None: """ Appends the given key/value configuration mapping, behaving the same as if we'd loaded this from a configuration file. @@ -657,7 +662,7 @@ def set(self, key, value, overwrite = True): else: raise ValueError("Config.set() only accepts str (bytes or unicode), list, or tuple. Provided value was a '%s'" % type(value)) - def get(self, key, default = None): + def get(self, key: str, default: Optional[Any] = None) -> Any: """ Fetches the given configuration, using the key and default value to determine the type it should be. 
Recognized inferences are: @@ -737,7 +742,7 @@ def get(self, key, default = None): return val - def get_value(self, key, default = None, multiple = False): + def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, List[str]]: """ This provides the current value associated with a given key. @@ -763,6 +768,6 @@ def get_value(self, key, default = None, multiple = False): log.log_once(message_id, log.TRACE, "config entry '%s' not found, defaulting to '%s'" % (key, default)) return default - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: with self._contents_lock: return self._contents[key] diff --git a/stem/util/connection.py b/stem/util/connection.py index eaeafec42..21745c43c 100644 --- a/stem/util/connection.py +++ b/stem/util/connection.py @@ -65,6 +65,7 @@ import stem.util.system from stem.util import conf, enum, log, str_tools +from typing import List, Optional, Sequence, Tuple, Union # Connection resolution is risky to log about since it's highly likely to # contain sensitive information. That said, it's also difficult to get right in @@ -157,15 +158,15 @@ class Connection(collections.namedtuple('Connection', ['local_address', 'local_p """ -def download(url, timeout = None, retries = None): +def download(url: str, timeout: Optional[float] = None, retries: Optional[int] = None) -> bytes: """ Download from the given url. .. 
versionadded:: 1.8.0 :param str url: uncompressed url to download from - :param int timeout: timeout when connection becomes idle, no timeout applied - if **None** + :param float timeout: timeout when connection becomes idle, no timeout + applied if **None** :param int retires: maximum attempts to impose :returns: **bytes** content of the given url @@ -185,20 +186,20 @@ def download(url, timeout = None, retries = None): except socket.timeout as exc: raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout) except: - exc, stacktrace = sys.exc_info()[1:3] + exception, stacktrace = sys.exc_info()[1:3] if timeout is not None: timeout -= time.time() - start_time if retries > 0 and (timeout is None or timeout > 0): - log.debug('Failed to download from %s (%i retries remaining): %s' % (url, retries, exc)) + log.debug('Failed to download from %s (%i retries remaining): %s' % (url, retries, exception)) return download(url, timeout, retries - 1) else: - log.debug('Failed to download from %s: %s' % (url, exc)) - raise stem.DownloadFailed(url, exc, stacktrace) + log.debug('Failed to download from %s: %s' % (url, exception)) + raise stem.DownloadFailed(url, exception, stacktrace) -def get_connections(resolver = None, process_pid = None, process_name = None): +def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, process_pid: Optional[int] = None, process_name: Optional[str] = None) -> Sequence['stem.util.connection.Connection']: """ Retrieves a list of the current connections for a given process. This provides a list of :class:`~stem.util.connection.Connection`. 
Note that @@ -239,7 +240,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None): if not process_pid and not process_name: raise ValueError('You must provide a pid or process name to provide connections for') - def _log(msg): + def _log(msg: str) -> None: if LOG_CONNECTION_RESOLUTION: log.debug(msg) @@ -253,7 +254,7 @@ def _log(msg): raise ValueError('Process pid was non-numeric: %s' % process_pid) if process_pid is None: - all_pids = stem.util.system.pid_by_name(process_name, True) + all_pids = stem.util.system.pid_by_name(process_name, True) # type: List[int] # type: ignore if len(all_pids) == 0: if resolver in (Resolver.NETSTAT_WINDOWS, Resolver.PROC, Resolver.BSD_PROCSTAT): @@ -288,7 +289,7 @@ def _log(msg): connections = [] resolver_regex = re.compile(resolver_regex_str) - def _parse_address_str(addr_type, addr_str, line): + def _parse_address_str(addr_type: str, addr_str: str, line: str) -> Tuple[str, int]: addr, port = addr_str.rsplit(':', 1) if not is_valid_ipv4_address(addr) and not is_valid_ipv6_address(addr, allow_brackets = True): @@ -334,7 +335,7 @@ def _parse_address_str(addr_type, addr_str, line): return connections -def system_resolvers(system = None): +def system_resolvers(system: Optional[str] = None) -> Sequence['stem.util.connection.Resolver']: """ Provides the types of connection resolvers likely to be available on this platform. @@ -383,7 +384,7 @@ def system_resolvers(system = None): return resolvers -def port_usage(port): +def port_usage(port: int) -> Optional[str]: """ Provides the common use of a given port. For example, 'HTTP' for port 80 or 'SSH' for 22. @@ -429,7 +430,7 @@ def port_usage(port): return PORT_USES.get(port) -def is_valid_ipv4_address(address): +def is_valid_ipv4_address(address: str) -> bool: """ Checks if a string is a valid IPv4 address. 
@@ -458,7 +459,7 @@ def is_valid_ipv4_address(address): return True -def is_valid_ipv6_address(address, allow_brackets = False): +def is_valid_ipv6_address(address: str, allow_brackets: bool = False) -> bool: """ Checks if a string is a valid IPv6 address. @@ -513,7 +514,7 @@ def is_valid_ipv6_address(address, allow_brackets = False): return True -def is_valid_port(entry, allow_zero = False): +def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_zero: bool = False) -> bool: """ Checks if a string or int is a valid port number. @@ -523,8 +524,15 @@ def is_valid_port(entry, allow_zero = False): :returns: **True** if input is an integer and within the valid port range, **False** otherwise """ + if isinstance(entry, (tuple, list)): + for port in entry: + if not is_valid_port(port, allow_zero): + return False + + return True + try: - value = int(entry) + value = int(entry) # type: ignore if str(value) != str(entry): return False # invalid leading char, e.g. space or zero @@ -533,19 +541,12 @@ def is_valid_port(entry, allow_zero = False): else: return value > 0 and value < 65536 except TypeError: - if isinstance(entry, (tuple, list)): - for port in entry: - if not is_valid_port(port, allow_zero): - return False - - return True - else: - return False + return False except ValueError: return False -def is_private_address(address): +def is_private_address(address: str) -> bool: """ Checks if the IPv4 address is in a range belonging to the local network or loopback. These include: @@ -581,7 +582,7 @@ def is_private_address(address): return False -def address_to_int(address): +def address_to_int(address: str) -> int: """ Provides an integer representation of a IPv4 or IPv6 address that can be used for sorting. 
@@ -599,7 +600,7 @@ def address_to_int(address): return int(_address_to_binary(address), 2) -def expand_ipv6_address(address): +def expand_ipv6_address(address: str) -> str: """ Expands abbreviated IPv6 addresses to their full colon separated hex format. For instance... @@ -620,6 +621,9 @@ def expand_ipv6_address(address): :raises: **ValueError** if the address can't be expanded due to being malformed """ + if isinstance(address, bytes): + address = str_tools._to_unicode(address) + if not is_valid_ipv6_address(address): raise ValueError("'%s' isn't a valid IPv6 address" % address) @@ -660,7 +664,7 @@ def expand_ipv6_address(address): return address -def get_mask_ipv4(bits): +def get_mask_ipv4(bits: int) -> str: """ Provides the IPv4 mask for a given number of bits, in the dotted-quad format. @@ -686,7 +690,7 @@ def get_mask_ipv4(bits): return '.'.join([str(int(octet, 2)) for octet in octets]) -def get_mask_ipv6(bits): +def get_mask_ipv6(bits: int) -> str: """ Provides the IPv6 mask for a given number of bits, in the hex colon-delimited format. @@ -713,7 +717,7 @@ def get_mask_ipv6(bits): return ':'.join(['%04x' % int(group, 2) for group in groupings]).upper() -def _get_masked_bits(mask): +def _get_masked_bits(mask: str) -> int: """ Provides the number of bits that an IPv4 subnet mask represents. Note that not all masks can be represented by a bit count. @@ -738,13 +742,15 @@ def _get_masked_bits(mask): raise ValueError('Unable to convert mask to a bit count: %s' % mask) -def _get_binary(value, bits): +def _get_binary(value: int, bits: int) -> str: """ Provides the given value as a binary string, padded with zeros to the given number of bits. 
:param int value: value to be converted :param int bits: number of bits to pad to + + :returns: **str** of this binary value """ # http://www.daniweb.com/code/snippet216539.html @@ -754,10 +760,12 @@ def _get_binary(value, bits): # TODO: In stem 2.x we should consider unifying this with # stem.client.datatype's _unpack_ipv4_address() and _unpack_ipv6_address(). -def _address_to_binary(address): +def _address_to_binary(address: str) -> str: """ Provides the binary value for an IPv4 or IPv6 address. + :param str address: address to convert + :returns: **str** with the binary representation of this address :raises: **ValueError** if address is neither an IPv4 nor IPv6 address diff --git a/stem/util/enum.py b/stem/util/enum.py index 56bf119dd..719a4c069 100644 --- a/stem/util/enum.py +++ b/stem/util/enum.py @@ -40,8 +40,10 @@ +- __iter__ - iterator over our enum keys """ +from typing import Any, Iterator, List, Sequence, Tuple, Union -def UppercaseEnum(*args): + +def UppercaseEnum(*args: str) -> 'Enum': """ Provides an :class:`~stem.util.enum.Enum` instance where the values are identical to the keys. Since the keys are uppercase by convention this means @@ -67,14 +69,15 @@ class Enum(object): Basic enumeration. """ - def __init__(self, *args): + def __init__(self, *args: Union[str, Tuple[str, Any]]) -> None: from stem.util.str_tools import _to_camel_case # ordered listings of our keys and values - keys, values = [], [] + keys = [] # type: List[str] + values = [] # type: List[Any] for entry in args: - if isinstance(entry, (bytes, str)): + if isinstance(entry, str): key, val = entry, _to_camel_case(entry) elif isinstance(entry, tuple) and len(entry) == 2: key, val = entry @@ -88,7 +91,7 @@ def __init__(self, *args): self._keys = tuple(keys) self._values = tuple(values) - def keys(self): + def keys(self) -> Sequence[str]: """ Provides an ordered listing of the enumeration keys in this set. 
@@ -97,11 +100,11 @@ def keys(self): return list(self._keys) - def index_of(self, value): + def index_of(self, value: Any) -> int: """ Provides the index of the given value in the collection. - :param str value: entry to be looked up + :param object value: entry to be looked up :returns: **int** index of the given entry @@ -110,11 +113,11 @@ def index_of(self, value): return self._values.index(value) - def next(self, value): + def next(self, value: Any) -> Any: """ Provides the next enumeration after the given value. - :param str value: enumeration for which to get the next entry + :param object value: enumeration for which to get the next entry :returns: enum value following the given entry @@ -127,11 +130,11 @@ def next(self, value): next_index = (self._values.index(value) + 1) % len(self._values) return self._values[next_index] - def previous(self, value): + def previous(self, value: Any) -> Any: """ Provides the previous enumeration before the given value. - :param str value: enumeration for which to get the previous entry + :param object value: enumeration for which to get the previous entry :returns: enum value proceeding the given entry @@ -144,13 +147,13 @@ def previous(self, value): prev_index = (self._values.index(value) - 1) % len(self._values) return self._values[prev_index] - def __getitem__(self, item): + def __getitem__(self, item: str) -> Any: """ Provides the values for the given key. - :param str item: key to be looked up + :param str item: key to look up - :returns: **str** with the value for the given key + :returns: value for the given key :raises: **ValueError** if the key doesn't exist """ @@ -161,7 +164,7 @@ def __getitem__(self, item): keys = ', '.join(self.keys()) raise ValueError("'%s' isn't among our enumeration keys, which includes: %s" % (item, keys)) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: """ Provides an ordered listing of the enums in this set. 
""" diff --git a/stem/util/log.py b/stem/util/log.py index 94d055ffc..404249a7f 100644 --- a/stem/util/log.py +++ b/stem/util/log.py @@ -92,10 +92,10 @@ class _NullHandler(logging.Handler): - def __init__(self): + def __init__(self) -> None: logging.Handler.__init__(self, level = logging.FATAL + 5) # disable logging - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: pass @@ -103,7 +103,7 @@ def emit(self, record): LOGGER.addHandler(_NullHandler()) -def get_logger(): +def get_logger() -> logging.Logger: """ Provides the stem logger. @@ -113,7 +113,7 @@ def get_logger(): return LOGGER -def logging_level(runlevel): +def logging_level(runlevel: 'stem.util.log.Runlevel') -> int: """ Translates a runlevel into the value expected by the logging module. @@ -126,7 +126,7 @@ def logging_level(runlevel): return logging.FATAL + 5 -def is_tracing(): +def is_tracing() -> bool: """ Checks if we're logging at the trace runlevel. @@ -142,7 +142,7 @@ def is_tracing(): return False -def escape(message): +def escape(message: str) -> str: """ Escapes specific sequences for logging (newlines, tabs, carriage returns). If the input is **bytes** then this converts it to **unicode** under python 3.x. @@ -160,7 +160,7 @@ def escape(message): return message -def log(runlevel, message): +def log(runlevel: 'stem.util.log.Runlevel', message: str) -> None: """ Logs a message at the given runlevel. @@ -172,7 +172,7 @@ def log(runlevel, message): LOGGER.log(LOG_VALUES[runlevel], message) -def log_once(message_id, runlevel, message): +def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> bool: """ Logs a message at the given runlevel. If a message with this ID has already been logged then this is a no-op. 
@@ -189,47 +189,48 @@ def log_once(message_id, runlevel, message): else: DEDUPLICATION_MESSAGE_IDS.add(message_id) log(runlevel, message) + return True # shorter aliases for logging at a runlevel -def trace(message): +def trace(message: str) -> None: log(Runlevel.TRACE, message) -def debug(message): +def debug(message: str) -> None: log(Runlevel.DEBUG, message) -def info(message): +def info(message: str) -> None: log(Runlevel.INFO, message) -def notice(message): +def notice(message: str) -> None: log(Runlevel.NOTICE, message) -def warn(message): +def warn(message: str) -> None: log(Runlevel.WARN, message) -def error(message): +def error(message: str) -> None: log(Runlevel.ERROR, message) class _StdoutLogger(logging.Handler): - def __init__(self, runlevel): + def __init__(self, runlevel: 'stem.util.log.Runlevel') -> None: logging.Handler.__init__(self, level = logging_level(runlevel)) self.formatter = logging.Formatter( fmt = '%(asctime)s [%(levelname)s] %(message)s', datefmt = '%m/%d/%Y %H:%M:%S') - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: print(self.formatter.format(record)) -def log_to_stdout(runlevel): +def log_to_stdout(runlevel: 'stem.util.log.Runlevel') -> None: """ Logs further events to stdout. diff --git a/stem/util/proc.py b/stem/util/proc.py index 3589af134..e180bb660 100644 --- a/stem/util/proc.py +++ b/stem/util/proc.py @@ -56,6 +56,7 @@ import stem.util.str_tools from stem.util import log +from typing import Any, Mapping, Optional, Sequence, Set, Tuple try: # unavailable on windows (#19823) @@ -80,7 +81,7 @@ @functools.lru_cache() -def is_available(): +def is_available() -> bool: """ Checks if proc information is available on this platform. @@ -101,7 +102,7 @@ def is_available(): @functools.lru_cache() -def system_start_time(): +def system_start_time() -> float: """ Provides the unix time (seconds since epoch) when the system started. 
@@ -124,7 +125,7 @@ def system_start_time(): @functools.lru_cache() -def physical_memory(): +def physical_memory() -> int: """ Provides the total physical memory on the system in bytes. @@ -146,7 +147,7 @@ def physical_memory(): raise exc -def cwd(pid): +def cwd(pid: int) -> str: """ Provides the current working directory for the given process. @@ -174,7 +175,7 @@ def cwd(pid): return cwd -def uid(pid): +def uid(pid: int) -> int: """ Provides the user ID the given process is running under. @@ -199,7 +200,7 @@ def uid(pid): raise exc -def memory_usage(pid): +def memory_usage(pid: int) -> Tuple[int, int]: """ Provides the memory usage in bytes for the given process. @@ -232,7 +233,7 @@ def memory_usage(pid): raise exc -def stats(pid, *stat_types): +def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[str]: """ Provides process specific information. See the :data:`~stem.util.proc.Stat` enum for valid options. @@ -270,6 +271,7 @@ def stats(pid, *stat_types): raise exc results = [] + for stat_type in stat_types: if stat_type == Stat.COMMAND: if pid == 0: @@ -288,7 +290,7 @@ def stats(pid, *stat_types): results.append(str(float(stat_comp[14]) / CLOCK_TICKS)) elif stat_type == Stat.START_TIME: if pid == 0: - return system_start_time() + results.append(str(system_start_time())) else: # According to documentation, starttime is in field 21 and the unit is # jiffies (clock ticks). We divide it for clock ticks, then add the @@ -300,7 +302,7 @@ def stats(pid, *stat_types): return tuple(results) -def file_descriptors_used(pid): +def file_descriptors_used(pid: int) -> int: """ Provides the number of file descriptors currently being used by a process. 
@@ -327,7 +329,7 @@ def file_descriptors_used(pid): raise IOError('Unable to check number of file descriptors used: %s' % exc) -def connections(pid = None, user = None): +def connections(pid: Optional[int] = None, user: Optional[str] = None) -> Sequence['stem.util.connection.Connection']: """ Queries connections from the proc contents. This matches netstat, lsof, and friends but is much faster. If no **pid** or **user** are provided this @@ -412,7 +414,7 @@ def connections(pid = None, user = None): raise -def _inodes_for_sockets(pid): +def _inodes_for_sockets(pid: int) -> Set[bytes]: """ Provides inodes in use by a process for its sockets. @@ -450,7 +452,7 @@ def _inodes_for_sockets(pid): return inodes -def _unpack_addr(addr): +def _unpack_addr(addr: bytes) -> str: """ Translates an address entry in the /proc/net/* contents to a human readable form (`reference `_, @@ -494,7 +496,7 @@ def _unpack_addr(addr): return ENCODED_ADDR[addr] -def _is_float(*value): +def _is_float(*value: Any) -> bool: try: for v in value: float(v) @@ -504,11 +506,11 @@ def _is_float(*value): return False -def _get_line(file_path, line_prefix, parameter): +def _get_line(file_path: str, line_prefix: str, parameter: str) -> str: return _get_lines(file_path, (line_prefix, ), parameter)[line_prefix] -def _get_lines(file_path, line_prefixes, parameter): +def _get_lines(file_path: str, line_prefixes: Sequence[str], parameter: str) -> Mapping[str, str]: """ Fetches lines with the given prefixes from a file. This only provides back the first instance of each prefix. @@ -552,7 +554,7 @@ def _get_lines(file_path, line_prefixes, parameter): raise -def _log_runtime(parameter, proc_location, start_time): +def _log_runtime(parameter: str, proc_location: str, start_time: float) -> None: """ Logs a message indicating a successful proc query. 
@@ -565,7 +567,7 @@ def _log_runtime(parameter, proc_location, start_time): log.debug('proc call (%s): %s (runtime: %0.4f)' % (parameter, proc_location, runtime)) -def _log_failure(parameter, exc): +def _log_failure(parameter: str, exc: BaseException) -> None: """ Logs a message indicating that the proc query failed. diff --git a/stem/util/str_tools.py b/stem/util/str_tools.py index c12856263..a0bef734f 100644 --- a/stem/util/str_tools.py +++ b/stem/util/str_tools.py @@ -26,6 +26,8 @@ import stem.util import stem.util.enum +from typing import List, Sequence, Tuple, Union, overload + # label conversion tuples of the form... # (bits / bytes / seconds, short label, long label) @@ -57,7 +59,7 @@ _timestamp_re = re.compile(r'(\d{4})-(\d{2})-(\d{2}) (\d{2}):(\d{2}):(\d{2})') -def _to_bytes(msg): +def _to_bytes(msg: Union[str, bytes]) -> bytes: """ Provides the ASCII bytes for the given string. This is purely to provide python 3 compatability, normalizing the unicode/ASCII change in the version @@ -71,12 +73,12 @@ def _to_bytes(msg): """ if isinstance(msg, str): - return codecs.latin_1_encode(msg, 'replace')[0] + return codecs.latin_1_encode(msg, 'replace')[0] # type: ignore else: return msg -def _to_unicode(msg): +def _to_unicode(msg: Union[str, bytes]) -> str: """ Provides the unicode string for the given ASCII bytes. This is purely to provide python 3 compatability, normalizing the unicode/ASCII change in the @@ -93,7 +95,7 @@ def _to_unicode(msg): return msg -def _decode_b64(msg): +def _decode_b64(msg: bytes) -> bytes: """ Base64 decode, without padding concerns. """ @@ -101,10 +103,10 @@ def _decode_b64(msg): missing_padding = len(msg) % 4 padding_chr = b'=' if isinstance(msg, bytes) else '=' - return base64.b64decode(msg + padding_chr * missing_padding) + return base64.b64decode(msg + (padding_chr * missing_padding)) -def _to_int(msg): +def _to_int(msg: Union[str, bytes]) -> int: """ Serializes a string to a number. 
@@ -120,7 +122,7 @@ def _to_int(msg): return sum([pow(256, (len(msg) - i - 1)) * ord(c) for (i, c) in enumerate(msg)]) -def _to_camel_case(label, divider = '_', joiner = ' '): +def _to_camel_case(label: str, divider: str = '_', joiner: str = ' ') -> str: """ Converts the given string to camel case, ie: @@ -148,6 +150,16 @@ def _to_camel_case(label, divider = '_', joiner = ' '): return joiner.join(words) +@overload +def _split_by_length(msg: bytes, size: int) -> List[bytes]: + ... + + +@overload +def _split_by_length(msg: str, size: int) -> List[str]: + ... + + def _split_by_length(msg, size): """ Splits a string into a list of strings up to the given size. @@ -172,7 +184,7 @@ def _split_by_length(msg, size): Ending = stem.util.enum.Enum('ELLIPSE', 'HYPHEN') -def crop(msg, size, min_word_length = 4, min_crop = 0, ending = Ending.ELLIPSE, get_remainder = False): +def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> Union[str, Tuple[str, str]]: """ Shortens a string to a given length. @@ -286,7 +298,7 @@ def crop(msg, size, min_word_length = 4, min_crop = 0, ending = Ending.ELLIPSE, return (return_msg, remainder) if get_remainder else return_msg -def size_label(byte_count, decimal = 0, is_long = False, is_bytes = True, round = False): +def size_label(byte_count: int, decimal: int = 0, is_long: bool = False, is_bytes: bool = True, round: bool = False) -> str: """ Converts a number of bytes into a human readable label in its most significant units. For instance, 7500 bytes would return "7 KB". 
If the @@ -323,7 +335,7 @@ def size_label(byte_count, decimal = 0, is_long = False, is_bytes = True, round return _get_label(SIZE_UNITS_BITS, byte_count, decimal, is_long, round) -def time_label(seconds, decimal = 0, is_long = False): +def time_label(seconds: int, decimal: int = 0, is_long: bool = False) -> str: """ Converts seconds into a time label truncated to its most significant units. For instance, 7500 seconds would return "2h". Units go up through days. @@ -354,7 +366,7 @@ def time_label(seconds, decimal = 0, is_long = False): return _get_label(TIME_UNITS, seconds, decimal, is_long) -def time_labels(seconds, is_long = False): +def time_labels(seconds: int, is_long: bool = False) -> Sequence[str]: """ Provides a list of label conversions for each time unit, starting with its most significant units on down. Any counts that evaluate to zero are omitted. @@ -379,12 +391,12 @@ def time_labels(seconds, is_long = False): for count_per_unit, _, _ in TIME_UNITS: if abs(seconds) >= count_per_unit: time_labels.append(_get_label(TIME_UNITS, seconds, 0, is_long)) - seconds %= count_per_unit + seconds %= int(count_per_unit) return time_labels -def short_time_label(seconds): +def short_time_label(seconds: int) -> str: """ Provides a time in the following format: [[dd-]hh:]mm:ss @@ -411,7 +423,7 @@ def short_time_label(seconds): for amount, _, label in TIME_UNITS: count = int(seconds / amount) - seconds %= amount + seconds %= int(amount) time_comp[label.strip()] = count label = '%02i:%02i' % (time_comp['minute'], time_comp['second']) @@ -424,7 +436,7 @@ def short_time_label(seconds): return label -def parse_short_time_label(label): +def parse_short_time_label(label: str) -> int: """ Provides the number of seconds corresponding to the formatting used for the cputime and etime fields of ps: @@ -469,7 +481,7 @@ def parse_short_time_label(label): raise ValueError('Non-numeric value in time entry: %s' % label) -def _parse_timestamp(entry): +def _parse_timestamp(entry: str) -> 
datetime.datetime: """ Parses the date and time that in format like like... @@ -495,7 +507,7 @@ def _parse_timestamp(entry): return datetime.datetime(time[0], time[1], time[2], time[3], time[4], time[5]) -def _parse_iso_timestamp(entry): +def _parse_iso_timestamp(entry: str) -> 'datetime.datetime': """ Parses the ISO 8601 standard that provides for timestamps like... @@ -533,7 +545,7 @@ def _parse_iso_timestamp(entry): return timestamp + datetime.timedelta(microseconds = int(microseconds)) -def _get_label(units, count, decimal, is_long, round = False): +def _get_label(units: Sequence[Tuple[float, str, str]], count: int, decimal: int, is_long: bool, round: bool = False) -> str: """ Provides label corresponding to units of the highest significance in the provided set. This rounds down (ie, integer truncation after visible units). @@ -578,3 +590,5 @@ def _get_label(units, count, decimal, is_long, round = False): return count_label + long_label + ('s' if is_plural else '') else: return count_label + short_label + + raise ValueError('BUG: value should always be divisible by a unit (%s)' % str(units)) diff --git a/stem/util/system.py b/stem/util/system.py index b3dee1518..a51479768 100644 --- a/stem/util/system.py +++ b/stem/util/system.py @@ -82,6 +82,7 @@ from stem import UNDEFINED from stem.util import log +from typing import Any, BinaryIO, Callable, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union State = stem.util.enum.UppercaseEnum( 'PENDING', @@ -97,11 +98,11 @@ dict: lambda d: itertools.chain.from_iterable(d.items()), set: iter, frozenset: iter, -} +} # type: Dict[Type, Callable] # Mapping of commands to if they're available or not. -CMD_AVAILABLE_CACHE = {} +CMD_AVAILABLE_CACHE = {} # type: Dict[str, bool] # An incomplete listing of commands provided by the shell. Expand this as # needed. Some noteworthy things about shell commands... 
@@ -185,11 +186,11 @@ class CallError(OSError): :var str command: command that was ran :var int exit_status: exit code of the process :var float runtime: time the command took to run - :var str stdout: stdout of the process - :var str stderr: stderr of the process + :var bytes stdout: stdout of the process + :var bytes stderr: stderr of the process """ - def __init__(self, msg, command, exit_status, runtime, stdout, stderr): + def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: bytes, stderr: bytes) -> None: self.msg = msg self.command = command self.exit_status = exit_status @@ -197,7 +198,7 @@ def __init__(self, msg, command, exit_status, runtime, stdout, stderr): self.stdout = stdout self.stderr = stderr - def __str__(self): + def __str__(self) -> str: return self.msg @@ -210,7 +211,7 @@ class CallTimeoutError(CallError): :var float timeout: time we waited """ - def __init__(self, msg, command, exit_status, runtime, stdout, stderr, timeout): + def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: bytes, stderr: bytes, timeout: float) -> None: super(CallTimeoutError, self).__init__(msg, command, exit_status, runtime, stdout, stderr) self.timeout = timeout @@ -231,7 +232,7 @@ class DaemonTask(object): :var exception error: exception raised by subprocess if it failed """ - def __init__(self, runner, args = None, priority = 15, start = False): + def __init__(self, runner: Callable, args: Optional[Sequence[Any]] = None, priority: int = 15, start: bool = False) -> None: self.runner = runner self.args = args self.priority = priority @@ -241,13 +242,13 @@ def __init__(self, runner, args = None, priority = 15, start = False): self.result = None self.error = None - self._process = None - self._pipe = None + self._process = None # type: Optional[multiprocessing.Process] + self._pipe = None # type: Optional[multiprocessing.connection.Connection] if start: self.run() - def run(self): + def run(self) -> None: """ 
Invokes the task if it hasn't already been started. If it has this is a no-op. @@ -259,7 +260,7 @@ def run(self): self._process.start() self.status = State.RUNNING - def join(self): + def join(self) -> Any: """ Provides the result of the daemon task. If still running this blocks until the task is completed. @@ -292,7 +293,7 @@ def join(self): raise RuntimeError('BUG: unexpected status from daemon task, %s' % self.status) @staticmethod - def _run_wrapper(conn, priority, runner, args): + def _run_wrapper(conn: 'multiprocessing.connection.Connection', priority: int, runner: Callable, args: Sequence[Any]) -> None: start_time = time.time() os.nice(priority) @@ -305,7 +306,7 @@ def _run_wrapper(conn, priority, runner, args): conn.close() -def is_windows(): +def is_windows() -> bool: """ Checks if we are running on Windows. @@ -315,7 +316,7 @@ def is_windows(): return platform.system() == 'Windows' -def is_mac(): +def is_mac() -> bool: """ Checks if we are running on Mac OSX. @@ -325,7 +326,7 @@ def is_mac(): return platform.system() == 'Darwin' -def is_gentoo(): +def is_gentoo() -> bool: """ Checks if we're running on Gentoo. @@ -335,7 +336,7 @@ def is_gentoo(): return os.path.exists('/etc/gentoo-release') -def is_slackware(): +def is_slackware() -> bool: """ Checks if we are running on a Slackware system. @@ -345,7 +346,7 @@ def is_slackware(): return os.path.exists('/etc/slackware-version') -def is_bsd(): +def is_bsd() -> bool: """ Checks if we are within the BSD family of operating systems. This currently recognizes Macs, FreeBSD, and OpenBSD but may be expanded later. @@ -356,7 +357,7 @@ def is_bsd(): return platform.system() in ('Darwin', 'FreeBSD', 'OpenBSD', 'NetBSD') -def is_available(command, cached=True): +def is_available(command: str, cached: bool = True) -> bool: """ Checks the current PATH to see if a command is available or not. 
If more than one command is present (for instance "ls -a | grep foo") then this @@ -399,7 +400,7 @@ def is_available(command, cached=True): return cmd_exists -def is_running(command): +def is_running(command: Union[str, int, Sequence[str]]) -> bool: """ Checks for if a process with a given name or pid is running. @@ -461,7 +462,7 @@ def is_running(command): return None -def size_of(obj, exclude = None): +def size_of(obj: Any, exclude: Optional[Collection[int]] = None) -> int: """ Provides the `approximate memory usage of an object `_. This can recurse tuples, @@ -485,9 +486,9 @@ def size_of(obj, exclude = None): if platform.python_implementation() == 'PyPy': raise NotImplementedError('PyPy does not implement sys.getsizeof()') - if exclude is None: - exclude = set() - elif id(obj) in exclude: + exclude = set(exclude) if exclude is not None else set() + + if id(obj) in exclude: return 0 try: @@ -504,7 +505,7 @@ def size_of(obj, exclude = None): return size -def name_by_pid(pid): +def name_by_pid(pid: int) -> Optional[str]: """ Attempts to determine the name a given process is running under (not including arguments). This uses... @@ -547,7 +548,7 @@ def name_by_pid(pid): return process_name -def pid_by_name(process_name, multiple = False): +def pid_by_name(process_name: str, multiple: bool = False) -> Union[int, List[int]]: """ Attempts to determine the process id for a running process, using... @@ -718,7 +719,7 @@ def pid_by_name(process_name, multiple = False): return [] if multiple else None -def pid_by_port(port): +def pid_by_port(port: int) -> Optional[int]: """ Attempts to determine the process id for a process with the given port, using... @@ -838,7 +839,7 @@ def pid_by_port(port): return None # all queries failed -def pid_by_open_file(path): +def pid_by_open_file(path: str) -> Optional[int]: """ Attempts to determine the process id for a process with the given open file, using... 
@@ -876,7 +877,7 @@ def pid_by_open_file(path): return None # all queries failed -def pids_by_user(user): +def pids_by_user(user: str) -> Optional[Sequence[int]]: """ Provides processes owned by a given user. @@ -908,7 +909,7 @@ def pids_by_user(user): return None -def cwd(pid): +def cwd(pid: int) -> Optional[str]: """ Provides the working directory of the given process. @@ -977,7 +978,7 @@ def cwd(pid): return None # all queries failed -def user(pid): +def user(pid: int) -> Optional[str]: """ Provides the user a process is running under. @@ -995,10 +996,8 @@ def user(pid): import pwd # only available on unix platforms uid = stem.util.proc.uid(pid) - - if uid and uid.isdigit(): - return pwd.getpwuid(int(uid)).pw_name - except: + return pwd.getpwuid(uid).pw_name + except ImportError: pass if is_available('ps'): @@ -1010,7 +1009,7 @@ def user(pid): return None -def start_time(pid): +def start_time(pid: str) -> Optional[float]: """ Provides the unix timestamp when the given process started. @@ -1041,7 +1040,7 @@ def start_time(pid): return None -def tail(target, lines = None): +def tail(target: Union[str, BinaryIO], lines: Optional[int] = None) -> Iterator[str]: """ Provides lines of a file starting with the end. For instance, 'tail -n 50 /tmp/my_log' could be done with... @@ -1060,8 +1059,8 @@ def tail(target, lines = None): if isinstance(target, str): with open(target, 'rb') as target_file: - for line in tail(target_file, lines): - yield line + for tail_line in tail(target_file, lines): + yield tail_line return @@ -1094,7 +1093,7 @@ def tail(target, lines = None): block_number -= 1 -def bsd_jail_id(pid): +def bsd_jail_id(pid: int) -> int: """ Gets the jail id for a process. These seem to only exist for FreeBSD (this style for jails does not exist on Linux, OSX, or OpenBSD). @@ -1129,7 +1128,7 @@ def bsd_jail_id(pid): return 0 -def bsd_jail_path(jid): +def bsd_jail_path(jid: int) -> Optional[str]: """ Provides the path of the given FreeBSD jail. 
@@ -1151,7 +1150,7 @@ def bsd_jail_path(jid): return None -def is_tarfile(path): +def is_tarfile(path: str) -> bool: """ Returns if the path belongs to a tarfile or not. @@ -1177,7 +1176,7 @@ def is_tarfile(path): return mimetypes.guess_type(path)[0] == 'application/x-tar' -def expand_path(path, cwd = None): +def expand_path(path: str, cwd: Optional[str] = None) -> str: """ Provides an absolute path, expanding tildes with the user's home and appending a current working directory if the path was relative. @@ -1222,7 +1221,7 @@ def expand_path(path, cwd = None): return relative_path -def files_with_suffix(base_path, suffix): +def files_with_suffix(base_path: str, suffix: str) -> Iterator[str]: """ Iterates over files in a given directory, providing filenames with a certain suffix. @@ -1245,7 +1244,7 @@ def files_with_suffix(base_path, suffix): yield os.path.join(root, filename) -def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = None, cwd = None, env = None): +def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_exit_status: bool = False, timeout: Optional[float] = None, cwd: Optional[str] = None, env: Optional[Mapping[str, str]] = None) -> Sequence[str]: """ call(command, default = UNDEFINED, ignore_exit_status = False) @@ -1298,7 +1297,7 @@ def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = Non if timeout: while process.poll() is None: if time.time() - start_time > timeout: - raise CallTimeoutError("Process didn't finish after %0.1f seconds" % timeout, ' '.join(command_list), None, timeout, '', '', timeout) + raise CallTimeoutError("Process didn't finish after %0.1f seconds" % timeout, ' '.join(command_list), None, timeout, b'', b'', timeout) time.sleep(0.001) @@ -1312,11 +1311,11 @@ def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = Non trace_prefix = 'Received from system (%s)' % command if stdout and stderr: - log.trace(trace_prefix + ', 
stdout:\n%s\nstderr:\n%s' % (stdout, stderr)) + log.trace(trace_prefix + ', stdout:\n%s\nstderr:\n%s' % (stdout.decode('utf-8'), stderr.decode('utf-8'))) elif stdout: - log.trace(trace_prefix + ', stdout:\n%s' % stdout) + log.trace(trace_prefix + ', stdout:\n%s' % stdout.decode('utf-8')) elif stderr: - log.trace(trace_prefix + ', stderr:\n%s' % stderr) + log.trace(trace_prefix + ', stderr:\n%s' % stderr.decode('utf-8')) exit_status = process.poll() @@ -1346,7 +1345,7 @@ def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = Non SYSTEM_CALL_TIME += time.time() - start_time -def get_process_name(): +def get_process_name() -> str: """ Provides the present name of our process. @@ -1398,7 +1397,7 @@ def get_process_name(): return _PROCESS_NAME -def set_process_name(process_name): +def set_process_name(process_name: str) -> None: """ Renames our current process from "python " to a custom name. This is best-effort, not necessarily working on all platforms. @@ -1432,7 +1431,7 @@ def set_process_name(process_name): _set_proc_title(process_name) -def _set_argv(process_name): +def _set_argv(process_name: str) -> None: """ Overwrites our argv in a similar fashion to how it's done in C with: strcpy(argv[0], 'new_name'); @@ -1462,7 +1461,7 @@ def _set_argv(process_name): _PROCESS_NAME = process_name -def _set_prctl_name(process_name): +def _set_prctl_name(process_name: str) -> None: """ Sets the prctl name, which is used by top and killall. This appears to be Linux specific and has the max of 15 characters. 
@@ -1477,7 +1476,7 @@ def _set_prctl_name(process_name): libc.prctl(PR_SET_NAME, ctypes.byref(name_buffer), 0, 0, 0) -def _set_proc_title(process_name): +def _set_proc_title(process_name: str) -> None: """ BSD specific calls (should be compataible with both FreeBSD and OpenBSD: http://fxr.watson.org/fxr/source/gen/setproctitle.c?v=FREEBSD-LIBC diff --git a/stem/util/term.py b/stem/util/term.py index 063914410..862767c45 100644 --- a/stem/util/term.py +++ b/stem/util/term.py @@ -50,6 +50,8 @@ import stem.util.enum import stem.util.str_tools +from typing import Optional, Union + TERM_COLORS = ('BLACK', 'RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA', 'CYAN', 'WHITE') # DISABLE_COLOR_SUPPORT is *not* being vended to Stem users. This is likely to @@ -70,18 +72,18 @@ RESET = CSI % '0' -def encoding(*attrs): +def encoding(*attrs: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> Optional[str]: """ Provides the ANSI escape sequence for these terminal color or attributes. .. versionadded:: 1.5.0 - :param list attr: :data:`~stem.util.terminal.Color`, - :data:`~stem.util.terminal.BgColor`, or :data:`~stem.util.terminal.Attr` to + :param list attr: :data:`~stem.util.term.Color`, + :data:`~stem.util.term.BgColor`, or :data:`~stem.util.term.Attr` to provide an ecoding for :returns: **str** of the ANSI escape sequence, **None** no attributes are - recognized + unrecognized """ term_encodings = [] @@ -97,9 +99,11 @@ def encoding(*attrs): if term_encodings: return CSI % ';'.join(term_encodings) + else: + return None -def format(msg, *attr): +def format(msg: str, *attr: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> str: """ Simple terminal text formatting using `ANSI escape sequences `_. 
@@ -118,17 +122,17 @@ def format(msg, *attr): :data:`~stem.util.term.BgColor`, or :data:`~stem.util.term.Attr` enums and are case insensitive (so strings like 'red' are fine) - :returns: **unicode** wrapped with ANSI escape encodings, starting with the given + :returns: **str** wrapped with ANSI escape encodings, starting with the given attributes and ending with a reset """ msg = stem.util.str_tools._to_unicode(msg) + attr = list(attr) if DISABLE_COLOR_SUPPORT: return msg if Attr.LINES in attr: - attr = list(attr) attr.remove(Attr.LINES) lines = [format(line, *attr) for line in msg.split('\n')] return '\n'.join(lines) diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py index 711652143..345734509 100644 --- a/stem/util/test_tools.py +++ b/stem/util/test_tools.py @@ -23,9 +23,11 @@ is_pyflakes_available - checks if pyflakes is available is_pycodestyle_available - checks if pycodestyle is available + is_mypy_available - checks if mypy is available pyflakes_issues - static checks for problems via pyflakes stylistic_issues - checks for PEP8 and other stylistic issues + type_issues - checks for type problems """ import collections @@ -42,20 +44,23 @@ import stem.util.enum import stem.util.system +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union + CONFIG = stem.util.conf.config_dict('test', { 'pycodestyle.ignore': [], 'pyflakes.ignore': [], + 'mypy.ignore': [], 'exclude_paths': [], }) -TEST_RUNTIMES = {} -ASYNC_TESTS = {} +TEST_RUNTIMES: Dict[str, float] = {} +ASYNC_TESTS: Dict[str, 'stem.util.test_tools.AsyncTest'] = {} AsyncStatus = stem.util.enum.UppercaseEnum('PENDING', 'RUNNING', 'FINISHED') AsyncResult = collections.namedtuple('AsyncResult', 'type msg') -def assert_equal(expected, actual, msg = None): +def assert_equal(expected: Any, actual: Any, msg: Optional[str] = None) -> None: """ Function form of a TestCase's assertEqual. 
@@ -72,7 +77,7 @@ def assert_equal(expected, actual, msg = None): raise AssertionError("Expected '%s' but was '%s'" % (expected, actual) if msg is None else msg) -def assert_in(expected, actual, msg = None): +def assert_in(expected: Any, actual: Any, msg: Optional[str] = None) -> None: """ Asserts that a given value is within this content. @@ -89,7 +94,7 @@ def assert_in(expected, actual, msg = None): raise AssertionError("Expected '%s' to be within '%s'" % (expected, actual) if msg is None else msg) -def skip(msg): +def skip(msg: str) -> None: """ Function form of a TestCase's skipTest. @@ -100,10 +105,12 @@ def skip(msg): :raises: **unittest.case.SkipTest** for this reason """ + # TODO: remove now that python 2.x is unsupported? + raise unittest.case.SkipTest(msg) -def asynchronous(func): +def asynchronous(func: Callable) -> Callable: test = stem.util.test_tools.AsyncTest(func) ASYNC_TESTS[test.name] = test return test.method @@ -131,7 +138,7 @@ def test_addition(): .. versionadded:: 1.6.0 """ - def __init__(self, runner, args = None, threaded = False): + def __init__(self, runner: Callable, args: Optional[Any] = None, threaded: bool = False) -> None: self.name = '%s.%s' % (runner.__module__, runner.__name__) self._runner = runner @@ -140,15 +147,15 @@ def __init__(self, runner, args = None, threaded = False): self.method = lambda test: self.result(test) # method that can be mixed into TestCases - self._process = None - self._process_pipe = None + self._process = None # type: Optional[Union[threading.Thread, multiprocessing.Process]] + self._process_pipe = None # type: Optional[multiprocessing.connection.Connection] self._process_lock = threading.RLock() - self._result = None + self._result = None # type: Optional[stem.util.test_tools.AsyncResult] self._status = AsyncStatus.PENDING - def run(self, *runner_args, **kwargs): - def _wrapper(conn, runner, args): + def run(self, *runner_args: Any, **kwargs: Any) -> None: + def _wrapper(conn: 
'multiprocessing.connection.Connection', runner: Callable, args: Any) -> None: os.nice(12) try: @@ -187,14 +194,14 @@ def _wrapper(conn, runner, args): self._process.start() self._status = AsyncStatus.RUNNING - def pid(self): + def pid(self) -> Optional[int]: with self._process_lock: - return self._process.pid if (self._process and not self._threaded) else None + return getattr(self._process, 'pid', None) - def join(self): + def join(self) -> None: self.result(None) - def result(self, test): + def result(self, test: 'unittest.TestCase') -> None: with self._process_lock: if self._status == AsyncStatus.PENDING: self.run() @@ -231,18 +238,18 @@ class TimedTestRunner(unittest.TextTestRunner): .. versionadded:: 1.6.0 """ - def run(self, test): - for t in test._tests: - original_type = type(t) + def run(self, test: Union[unittest.TestCase, unittest.TestSuite]) -> unittest.TestResult: + for t in getattr(test, '_tests', ()): + original_type = type(t) # type: Any class _TestWrapper(original_type): - def run(self, result = None): + def run(self, result: Optional[Any] = None) -> Any: start_time = time.time() result = super(type(self), self).run(result) TEST_RUNTIMES[self.id()] = time.time() - start_time return result - def assertRaisesWith(self, exc_type, exc_msg, func, *args, **kwargs): + def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None: """ Asserts the given invokation raises the expected excepiton. 
This is similar to unittest's assertRaises and assertRaisesRegexp, but checks @@ -255,10 +262,10 @@ def assertRaisesWith(self, exc_type, exc_msg, func, *args, **kwargs): return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs) - def id(self): + def id(self) -> str: return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName) - def __str__(self): + def __str__(self) -> str: return '%s (%s.%s)' % (self._testMethodName, original_type.__module__, original_type.__name__) t.__class__ = _TestWrapper @@ -266,7 +273,7 @@ def __str__(self): return super(TimedTestRunner, self).run(test) -def test_runtimes(): +def test_runtimes() -> Dict[str, float]: """ Provides the runtimes of tests executed through TimedTestRunners. @@ -279,7 +286,7 @@ def test_runtimes(): return dict(TEST_RUNTIMES) -def clean_orphaned_pyc(paths): +def clean_orphaned_pyc(paths: Sequence[str]) -> List[str]: """ Deletes any file with a \\*.pyc extention without a corresponding \\*.py. This helps to address a common gotcha when deleting python files... @@ -295,7 +302,7 @@ def clean_orphaned_pyc(paths): :param list paths: paths to search for orphaned pyc files - :returns: list of absolute paths that were deleted + :returns: **list** of absolute paths that were deleted """ orphaned_pyc = [] @@ -324,7 +331,7 @@ def clean_orphaned_pyc(paths): return orphaned_pyc -def is_pyflakes_available(): +def is_pyflakes_available() -> bool: """ Checks if pyflakes is availalbe. @@ -334,7 +341,7 @@ def is_pyflakes_available(): return _module_exists('pyflakes.api') and _module_exists('pyflakes.reporter') -def is_pycodestyle_available(): +def is_pycodestyle_available() -> bool: """ Checks if pycodestyle is availalbe. 
@@ -349,7 +356,17 @@ def is_pycodestyle_available(): return hasattr(pycodestyle, 'BaseReport') -def stylistic_issues(paths, check_newlines = False, check_exception_keyword = False, prefer_single_quotes = False): +def is_mypy_available() -> bool: + """ + Checks if mypy is available. + + :returns: **True** if we can use mypy and **False** otherwise + """ + + return _module_exists('mypy.api') + + +def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Dict[str, List['stem.util.test_tools.Issue']]: """ Checks for stylistic issues that are an issue according to the parts of PEP8 we conform to. You can suppress pycodestyle issues by making a 'test' @@ -407,7 +424,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa :returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances """ - issues = {} + issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]] ignore_rules = [] ignore_for_file = [] @@ -425,7 +442,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa else: ignore_rules.append(rule) - def is_ignored(path, rule, code): + def is_ignored(path: str, rule: str, code: str) -> bool: for ignored_path, ignored_rule, ignored_code in ignore_for_file: if path.endswith(ignored_path) and ignored_rule == rule and code.strip().startswith(ignored_code): return True @@ -440,7 +457,7 @@ def is_ignored(path, rule, code): import pycodestyle class StyleReport(pycodestyle.BaseReport): - def init_file(self, filename, lines, expected, line_offset): + def init_file(self, filename: str, lines: Sequence[str], expected: Tuple[str], line_offset: int) -> None: super(StyleReport, self).init_file(filename, lines, expected, line_offset) if not check_newlines and not check_exception_keyword and not prefer_single_quotes: @@ -473,7 +490,7 @@ def init_file(self, filename, lines, expected, line_offset): 
issues.setdefault(filename, []).append(Issue(index + 1, 'use single rather than double quotes', line)) - def error(self, line_number, offset, text, check): + def error(self, line_number: int, offset: int, text: str, check: str) -> None: code = super(StyleReport, self).error(line_number, offset, text, check) if code: @@ -488,7 +505,7 @@ def error(self, line_number, offset, text, check): return issues -def pyflakes_issues(paths): +def pyflakes_issues(paths: Sequence[str]) -> Dict[str, List['stem.util.test_tools.Issue']]: """ Performs static checks via pyflakes. False positives can be ignored via 'pyflakes.ignore' entries in our 'test' config. For instance... @@ -514,50 +531,31 @@ def pyflakes_issues(paths): :returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances """ - issues = {} + issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]] if is_pyflakes_available(): import pyflakes.api import pyflakes.reporter class Reporter(pyflakes.reporter.Reporter): - def __init__(self): - self._ignored_issues = {} + def __init__(self) -> None: + self._ignored_issues = {} # type: Dict[str, List[str]] for line in CONFIG['pyflakes.ignore']: path, issue = line.split('=>') self._ignored_issues.setdefault(path.strip(), []).append(issue.strip()) - def unexpectedError(self, filename, msg): + def unexpectedError(self, filename: str, msg: 'pyflakes.messages.Message') -> None: self._register_issue(filename, None, msg, None) - def syntaxError(self, filename, msg, lineno, offset, text): + def syntaxError(self, filename: str, msg: str, lineno: int, offset: int, text: str) -> None: self._register_issue(filename, lineno, msg, text) - def flake(self, msg): + def flake(self, msg: 'pyflakes.messages.Message') -> None: self._register_issue(msg.filename, msg.lineno, msg.message % msg.message_args, None) - def _is_ignored(self, path, issue): - # Paths in pyflakes_ignore are relative, so we need to check to see if our - # path ends with any of them. 
- - for ignored_path, ignored_issues in self._ignored_issues.items(): - if path.endswith(ignored_path): - if issue in ignored_issues: - return True - - for prefix in [i[:1] for i in ignored_issues if i.endswith('*')]: - if issue.startswith(prefix): - return True - - for suffix in [i[1:] for i in ignored_issues if i.startswith('*')]: - if issue.endswith(suffix): - return True - - return False - - def _register_issue(self, path, line_number, issue, line): - if not self._is_ignored(path, issue): + def _register_issue(self, path: str, line_number: int, issue: str, line: str) -> None: + if not _is_ignored(self._ignored_issues, path, issue): if path and line_number and not line: line = linecache.getline(path, line_number).strip() @@ -571,7 +569,68 @@ def _register_issue(self, path, line_number, issue, line): return issues -def _module_exists(module_name): +def type_issues(args: Sequence[str]) -> Dict[str, List['stem.util.test_tools.Issue']]: + """ + Performs type checks via mypy. False positives can be ignored via + 'mypy.ignore' entries in our 'test' config. For instance... + + :: + + mypy.ignore stem/util/system.py => Incompatible types in assignment* + + :param list args: mypy commmandline arguments + + :returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances + """ + + issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]] + + if is_mypy_available(): + import mypy.api + + ignored_issues = {} # type: Dict[str, List[str]] + + for line in CONFIG['mypy.ignore']: + path, issue = line.split('=>') + ignored_issues.setdefault(path.strip(), []).append(issue.strip()) + + # mypy returns (report, errors, exit_status) + + lines = mypy.api.run(args)[0].splitlines() # type: ignore + + for line in lines: + # example: + # stem/util/__init__.py:89: error: Incompatible return value type (got "Union[bytes, str]", expected "bytes") + + if line.startswith('Found ') and line.endswith(' source files)'): + continue # ex. 
"Found 1786 errors in 45 files (checked 49 source files)" + elif line.count(':') < 3: + raise ValueError('Failed to parse mypy line: %s' % line) + + path, line_number, _, issue = line.split(':', 3) + + if not line_number.isdigit(): + raise ValueError('Malformed line number on: %s' % line) + + issue = issue.strip() + line_number = int(line_number) + + if _is_ignored(ignored_issues, path, issue): + continue + + # skip getting code if there's too many reported issues + + if len(lines) < 25: + line = linecache.getline(path, line_number).strip() + else: + line = '' + + issues.setdefault(path, []).append(Issue(line_number, issue, line)) + + return issues + + +def _module_exists(module_name: str) -> bool: """ Checks if a module exists. @@ -587,7 +646,7 @@ def _module_exists(module_name): return False -def _python_files(paths): +def _python_files(paths: Sequence[str]) -> Iterator[str]: for path in paths: for file_path in stem.util.system.files_with_suffix(path, '.py'): skip = False @@ -599,3 +658,25 @@ def _python_files(paths): if not skip: yield file_path + + +def _is_ignored(config: Mapping[str, Sequence[str]], path: str, issue: str) -> bool: + for ignored_path, ignored_issues in config.items(): + if ignored_path == '*' or path.endswith(ignored_path): + for ignored_issue in ignored_issues: + if issue == ignored_issue: + return True + + # TODO: try using glob module instead? 
+ + if ignored_issue.startswith('*') and ignored_issue.endswith('*'): + if ignored_issue[1:-1] in issue: + return True # substring match + elif ignored_issue.startswith('*'): + if issue.endswith(ignored_issue[1:]): + return True # prefix match + elif ignored_issue.endswith('*'): + if issue.startswith(ignored_issue[:-1]): + return True # suffix match + + return False diff --git a/stem/util/tor_tools.py b/stem/util/tor_tools.py index 8987635eb..2398b7bc5 100644 --- a/stem/util/tor_tools.py +++ b/stem/util/tor_tools.py @@ -23,6 +23,8 @@ import stem.util.str_tools +from typing import Optional, Sequence, Union + # The control-spec defines the following as... # # Fingerprint = "$" 40*HEXDIG @@ -45,7 +47,7 @@ HS_V3_ADDRESS_PATTERN = re.compile('^[a-z2-7]{56}$') -def is_valid_fingerprint(entry, check_prefix = False): +def is_valid_fingerprint(entry: str, check_prefix: bool = False) -> bool: """ Checks if a string is a properly formatted relay fingerprint. This checks for a '$' prefix if check_prefix is true, otherwise this only validates the hex @@ -72,11 +74,11 @@ def is_valid_fingerprint(entry, check_prefix = False): return False -def is_valid_nickname(entry): +def is_valid_nickname(entry: str) -> bool: """ Checks if a string is a valid format for being a nickname. - :param str entry: string to be checked + :param str entry: string to check :returns: **True** if the string could be a nickname, **False** otherwise """ @@ -90,10 +92,12 @@ def is_valid_nickname(entry): return False -def is_valid_circuit_id(entry): +def is_valid_circuit_id(entry: str) -> bool: """ Checks if a string is a valid format for being a circuit identifier. + :param str entry: string to check + :returns: **True** if the string could be a circuit id, **False** otherwise """ @@ -106,29 +110,33 @@ def is_valid_circuit_id(entry): return False -def is_valid_stream_id(entry): +def is_valid_stream_id(entry: str) -> bool: """ Checks if a string is a valid format for being a stream identifier. 
Currently, this is just an alias to :func:`~stem.util.tor_tools.is_valid_circuit_id`. + :param str entry: string to check + :returns: **True** if the string could be a stream id, **False** otherwise """ return is_valid_circuit_id(entry) -def is_valid_connection_id(entry): +def is_valid_connection_id(entry: str) -> bool: """ Checks if a string is a valid format for being a connection identifier. Currently, this is just an alias to :func:`~stem.util.tor_tools.is_valid_circuit_id`. + :param str entry: string to check + :returns: **True** if the string could be a connection id, **False** otherwise """ return is_valid_circuit_id(entry) -def is_valid_hidden_service_address(entry, version = None): +def is_valid_hidden_service_address(entry: str, version: Optional[Union[int, Sequence[int]]] = None) -> bool: """ Checks if a string is a valid format for being a hidden service address (not including the '.onion' suffix). @@ -137,6 +145,7 @@ def is_valid_hidden_service_address(entry, version = None): Added the **version** argument, and responds with **True** if a version 3 hidden service address rather than just version 2 addresses. + :param str entry: string to check :param int,list version: versions to check for, if unspecified either v2 or v3 hidden service address will provide **True** @@ -166,7 +175,7 @@ def is_valid_hidden_service_address(entry, version = None): return False -def is_hex_digits(entry, count): +def is_hex_digits(entry: str, count: int) -> bool: """ Checks if a string is the given number of hex digits. Digits represented by letters are case insensitive. 
diff --git a/stem/version.py b/stem/version.py index 181aec8af..6ef7c890c 100644 --- a/stem/version.py +++ b/stem/version.py @@ -42,13 +42,15 @@ import stem.util.enum import stem.util.system +from typing import Any, Callable + # cache for the get_system_tor_version function VERSION_CACHE = {} VERSION_PATTERN = re.compile(r'^([0-9]+)\.([0-9]+)\.([0-9]+)(\.[0-9]+)?(-\S*)?(( \(\S*\))*)$') -def get_system_tor_version(tor_cmd = 'tor'): +def get_system_tor_version(tor_cmd: str = 'tor') -> 'stem.version.Version': """ Queries tor for its version. This is os dependent, only working on linux, osx, and bsd. @@ -70,9 +72,9 @@ def get_system_tor_version(tor_cmd = 'tor'): if 'No such file or directory' in str(exc): if os.path.isabs(tor_cmd): - exc = "Unable to check tor's version. '%s' doesn't exist." % tor_cmd + raise IOError("Unable to check tor's version. '%s' doesn't exist." % tor_cmd) else: - exc = "Unable to run '%s'. Maybe tor isn't in your PATH?" % version_cmd + raise IOError("Unable to run '%s'. Maybe tor isn't in your PATH?" % version_cmd) raise IOError(exc) @@ -96,7 +98,7 @@ def get_system_tor_version(tor_cmd = 'tor'): @functools.lru_cache() -def _get_version(version_str): +def _get_version(version_str: str) -> 'stem.version.Version': return Version(version_str) @@ -125,18 +127,17 @@ class Version(object): :raises: **ValueError** if input isn't a valid tor version """ - def __init__(self, version_str): + def __init__(self, version_str: str) -> None: self.version_str = version_str version_parts = VERSION_PATTERN.match(version_str) if version_parts: - major, minor, micro, patch, status, extra_str, _ = version_parts.groups() + major, minor, micro, patch_str, status, extra_str, _ = version_parts.groups() # The patch and status matches are optional (may be None) and have an extra # proceeding period or dash if they exist. Stripping those off. 
- if patch: - patch = int(patch[1:]) + patch = int(patch_str[1:]) if patch_str else None if status: status = status[1:] @@ -157,14 +158,14 @@ def __init__(self, version_str): else: raise ValueError("'%s' isn't a properly formatted tor version" % version_str) - def __str__(self): + def __str__(self) -> str: """ Provides the string used to construct the version. """ return self.version_str - def _compare(self, other, method): + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool: """ Compares version ordering according to the spec. """ @@ -195,23 +196,23 @@ def _compare(self, other, method): return method(my_status, other_status) - def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'major', 'minor', 'micro', 'patch', 'status', cache = True) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s == o) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __gt__(self, other): + def __gt__(self, other: Any) -> bool: """ Checks if this version meets the requirements for a given feature. """ return self._compare(other, lambda s, o: s > o) - def __ge__(self, other): + def __ge__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s >= o) diff --git a/test/arguments.py b/test/arguments.py index d0f0dc3fd..e06148c41 100644 --- a/test/arguments.py +++ b/test/arguments.py @@ -5,13 +5,14 @@ Commandline argument parsing for our test runner. 
""" -import collections import getopt import stem.util.conf import stem.util.log import test +from typing import Any, Dict, List, NamedTuple, Optional, Sequence + LOG_TYPE_ERROR = """\ '%s' isn't a logging runlevel, use one of the following instead: TRACE, DEBUG, INFO, NOTICE, WARN, ERROR @@ -23,138 +24,136 @@ 'target.torrc': {}, }) -DEFAULT_ARGS = { - 'run_unit': False, - 'run_integ': False, - 'specific_test': [], - 'exclude_test': [], - 'logging_runlevel': None, - 'logging_path': None, - 'tor_path': 'tor', - 'run_targets': [test.Target.RUN_OPEN], - 'attribute_targets': [], - 'quiet': False, - 'verbose': False, - 'print_help': False, -} - OPT = 'auit:l:qvh' OPT_EXPANDED = ['all', 'unit', 'integ', 'targets=', 'test=', 'exclude-test=', 'log=', 'log-file=', 'tor=', 'quiet', 'verbose', 'help'] -def parse(argv): - """ - Parses our arguments, providing a named tuple with their values. +class Arguments(NamedTuple): + run_unit: bool = False + run_integ: bool = False + specific_test: List[str] = [] + exclude_test: List[str] = [] + logging_runlevel: Optional[str] = None + logging_path: Optional[str] = None + tor_path: str = 'tor' + run_targets: List['test.Target'] = [test.Target.RUN_OPEN] + attribute_targets: List['test.Target'] = [] + quiet: bool = False + verbose: bool = False + print_help: bool = False + + @staticmethod + def parse(argv: Sequence[str]) -> 'test.arguments.Arguments': + """ + Parses our commandline arguments into this class. 
+ + :param list argv: input arguments to be parsed + + :returns: :class:`test.arguments.Arguments` for this commandline input + + :raises: **ValueError** if we got an invalid argument + """ - :param list argv: input arguments to be parsed + args = {} # type: Dict[str, Any] - :returns: a **named tuple** with our parsed arguments + try: + recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) # type: ignore - :raises: **ValueError** if we got an invalid argument - """ + if unrecognized_args: + error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" + raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) + except Exception as exc: + raise ValueError('%s (for usage provide --help)' % exc) - args = dict(DEFAULT_ARGS) - - try: - recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) - - if unrecognized_args: - error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" - raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) - except Exception as exc: - raise ValueError('%s (for usage provide --help)' % exc) - - for opt, arg in recognized_args: - if opt in ('-a', '--all'): - args['run_unit'] = True - args['run_integ'] = True - elif opt in ('-u', '--unit'): - args['run_unit'] = True - elif opt in ('-i', '--integ'): - args['run_integ'] = True - elif opt in ('-t', '--targets'): - run_targets, attribute_targets = [], [] - - integ_targets = arg.split(',') - all_run_targets = [t for t in test.Target if CONFIG['target.torrc'].get(t) is not None] - - # validates the targets and split them into run and attribute targets - - if not integ_targets: - raise ValueError('No targets provided') - - for target in integ_targets: - if target not in test.Target: - raise ValueError('Invalid integration target: %s' % target) - elif target in all_run_targets: - run_targets.append(target) - else: - 
attribute_targets.append(target) - - # check if we were told to use all run targets - - if test.Target.RUN_ALL in attribute_targets: - attribute_targets.remove(test.Target.RUN_ALL) - run_targets = all_run_targets - - # if no RUN_* targets are provided then keep the default (otherwise we - # won't have any tests to run) - - if run_targets: - args['run_targets'] = run_targets - - args['attribute_targets'] = attribute_targets - elif opt == '--test': - args['specific_test'].append(crop_module_name(arg)) - elif opt == '--exclude-test': - args['exclude_test'].append(crop_module_name(arg)) - elif opt in ('-l', '--log'): - arg = arg.upper() - - if arg not in stem.util.log.LOG_VALUES: - raise ValueError(LOG_TYPE_ERROR % arg) - - args['logging_runlevel'] = arg - elif opt == '--log-file': - args['logging_path'] = arg - elif opt in ('--tor'): - args['tor_path'] = arg - elif opt in ('-q', '--quiet'): - args['quiet'] = True - elif opt in ('-v', '--verbose'): - args['verbose'] = True - elif opt in ('-h', '--help'): - args['print_help'] = True - - # translates our args dict into a named tuple - - Args = collections.namedtuple('Args', args.keys()) - return Args(**args) - - -def get_help(): - """ - Provides usage information, as provided by the '--help' argument. This - includes a listing of the valid integration targets. 
+    for opt, arg in recognized_args:
+      if opt in ('-a', '--all'):
+        args['run_unit'] = True
+        args['run_integ'] = True
+      elif opt in ('-u', '--unit'):
+        args['run_unit'] = True
+      elif opt in ('-i', '--integ'):
+        args['run_integ'] = True
+      elif opt in ('-t', '--targets'):
+        run_targets, attribute_targets = [], []
-  :returns: **str** with our usage information
-  """
+        integ_targets = arg.split(',')
+        all_run_targets = [t for t in test.Target if CONFIG['target.torrc'].get(t) is not None]
+
+        # validates the targets and split them into run and attribute targets
+
+        if not integ_targets:
+          raise ValueError('No targets provided')
+
+        for target in integ_targets:
+          if target not in test.Target:
+            raise ValueError('Invalid integration target: %s' % target)
+          elif target in all_run_targets:
+            run_targets.append(target)
+          else:
+            attribute_targets.append(target)
+
+        # check if we were told to use all run targets
+
+        if test.Target.RUN_ALL in attribute_targets:
+          attribute_targets.remove(test.Target.RUN_ALL)
+          run_targets = all_run_targets
+
+        # if no RUN_* targets are provided then keep the default (otherwise we
+        # won't have any tests to run)
+
+        if run_targets:
+          args['run_targets'] = run_targets
+
+        args['attribute_targets'] = attribute_targets
+      elif opt == '--test':
+        args.setdefault('specific_test', []).append(crop_module_name(arg))
+      elif opt == '--exclude-test':
+        args.setdefault('exclude_test', []).append(crop_module_name(arg))
+      elif opt in ('-l', '--log'):
+        arg = arg.upper()
+
+        if arg not in stem.util.log.LOG_VALUES:
+          raise ValueError(LOG_TYPE_ERROR % arg)
+
+        args['logging_runlevel'] = arg
+      elif opt == '--log-file':
+        args['logging_path'] = arg
+      elif opt in ('--tor'):
+        args['tor_path'] = arg
+      elif opt in ('-q', '--quiet'):
+        args['quiet'] = True
+      elif opt in ('-v', '--verbose'):
+        args['verbose'] = True
+      elif opt in ('-h', '--help'):
+        args['print_help'] = True
+
+    return Arguments(**args)
+
+  @staticmethod
+  def get_help() -> str:
+    """
+    Provides usage information, as provided by
the '--help' argument. This + includes a listing of the valid integration targets. + + :returns: **str** with our usage information + """ + + help_msg = CONFIG['msg.help'] - help_msg = CONFIG['msg.help'] + # gets the longest target length so we can show the entries in columns - # gets the longest target length so we can show the entries in columns - target_name_length = max(map(len, test.Target)) - description_format = '\n %%-%is - %%s' % target_name_length + target_name_length = max(map(len, test.Target)) + description_format = '\n %%-%is - %%s' % target_name_length - for target in test.Target: - help_msg += description_format % (target, CONFIG['target.description'].get(target, '')) + for target in test.Target: + help_msg += description_format % (target, CONFIG['target.description'].get(target, '')) - help_msg += '\n' + help_msg += '\n' - return help_msg + return help_msg -def crop_module_name(name): +def crop_module_name(name: str) -> str: """ Test modules have a 'test.unit.' or 'test.integ.' prefix which can be omitted from our '--test' argument. 
Cropping this so we can do diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py index 8b8b3205f..732ae50af 100644 --- a/test/integ/control/controller.py +++ b/test/integ/control/controller.py @@ -339,14 +339,14 @@ def test_protocolinfo(self): auth_methods = [] if test.runner.Torrc.COOKIE in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.COOKIE) - auth_methods.append(stem.response.protocolinfo.AuthMethod.SAFECOOKIE) + auth_methods.append(stem.connection.AuthMethod.COOKIE) + auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE) if test.runner.Torrc.PASSWORD in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.PASSWORD) + auth_methods.append(stem.connection.AuthMethod.PASSWORD) if not auth_methods: - auth_methods.append(stem.response.protocolinfo.AuthMethod.NONE) + auth_methods.append(stem.connection.AuthMethod.NONE) self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods) diff --git a/test/integ/response/protocolinfo.py b/test/integ/response/protocolinfo.py index 2fb060dbe..3a9ee0beb 100644 --- a/test/integ/response/protocolinfo.py +++ b/test/integ/response/protocolinfo.py @@ -125,8 +125,8 @@ def assert_matches_test_config(self, protocolinfo_response): auth_methods, auth_cookie_path = [], None if test.runner.Torrc.COOKIE in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.COOKIE) - auth_methods.append(stem.response.protocolinfo.AuthMethod.SAFECOOKIE) + auth_methods.append(stem.connection.AuthMethod.COOKIE) + auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE) chroot_path = runner.get_chroot() auth_cookie_path = runner.get_auth_cookie_path() @@ -135,10 +135,10 @@ def assert_matches_test_config(self, protocolinfo_response): auth_cookie_path = auth_cookie_path[len(chroot_path):] if test.runner.Torrc.PASSWORD in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.PASSWORD) + auth_methods.append(stem.connection.AuthMethod.PASSWORD) 
if not auth_methods: - auth_methods.append(stem.response.protocolinfo.AuthMethod.NONE) + auth_methods.append(stem.connection.AuthMethod.NONE) self.assertEqual((), protocolinfo_response.unknown_auth_methods) self.assertEqual(tuple(auth_methods), protocolinfo_response.auth_methods) diff --git a/test/mypy.ini b/test/mypy.ini new file mode 100644 index 000000000..1c77449a6 --- /dev/null +++ b/test/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +allow_redefinition = True +ignore_missing_imports = True +show_error_codes = True +strict_optional = False +warn_unused_ignores = True diff --git a/test/settings.cfg b/test/settings.cfg index 38a37ef98..51109f96e 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -192,20 +192,18 @@ pycodestyle.ignore test/unit/util/connection.py => W291: _tor tor 158 # False positives from pyflakes. These are mappings between the path and the # issue. -pyflakes.ignore run_tests.py => 'unittest' imported but unused -pyflakes.ignore stem/control.py => undefined name 'controller' -pyflakes.ignore stem/manual.py => undefined name 'unichr' +pyflakes.ignore stem/manual.py => undefined name 'sqlite3' +pyflakes.ignore stem/client/cell.py => undefined name 'cryptography' +pyflakes.ignore stem/client/cell.py => undefined name 'hashlib' pyflakes.ignore stem/client/datatype.py => redefinition of unused 'pop' from * -pyflakes.ignore stem/descriptor/hidden_service_descriptor.py => 'stem.descriptor.hidden_service.*' imported but unused -pyflakes.ignore stem/descriptor/hidden_service_descriptor.py => 'from stem.descriptor.hidden_service import *' used; unable to detect undefined names -pyflakes.ignore stem/interpreter/__init__.py => undefined name 'raw_input' -pyflakes.ignore stem/response/events.py => undefined name 'long' -pyflakes.ignore stem/util/__init__.py => undefined name 'long' -pyflakes.ignore stem/util/__init__.py => undefined name 'unicode' -pyflakes.ignore stem/util/conf.py => undefined name 'unicode' -pyflakes.ignore stem/util/test_tools.py => 'pyflakes' 
imported but unused -pyflakes.ignore stem/util/test_tools.py => 'pycodestyle' imported but unused -pyflakes.ignore test/__init__.py => undefined name 'test' +pyflakes.ignore stem/descriptor/__init__.py => undefined name 'cryptography' +pyflakes.ignore stem/descriptor/hidden_service.py => undefined name 'cryptography' +pyflakes.ignore stem/interpreter/autocomplete.py => undefined name 'stem' +pyflakes.ignore stem/interpreter/help.py => undefined name 'stem' +pyflakes.ignore stem/response/events.py => undefined name 'datetime' +pyflakes.ignore stem/socket.py => redefinition of unused '_recv'* +pyflakes.ignore stem/util/conf.py => undefined name 'stem' +pyflakes.ignore stem/util/enum.py => undefined name 'stem' pyflakes.ignore test/require.py => 'cryptography.utils.int_from_bytes' imported but unused pyflakes.ignore test/require.py => 'cryptography.utils.int_to_bytes' imported but unused pyflakes.ignore test/require.py => 'cryptography.hazmat.backends.default_backend' imported but unused @@ -216,9 +214,25 @@ pyflakes.ignore test/require.py => 'cryptography.hazmat.primitives.serialization pyflakes.ignore test/require.py => 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey' imported but unused pyflakes.ignore test/unit/response/events.py => 'from stem import *' used; unable to detect undefined names pyflakes.ignore test/unit/response/events.py => *may be undefined, or defined from star imports: stem -pyflakes.ignore stem/util/str_tools.py => undefined name 'unicode' pyflakes.ignore test/integ/interpreter.py => 'readline' imported but unused +# Our enum class confuses mypy. Ignore this until we can change to python 3.x's +# new enum builtin. +# +# For example... 
+# +# See https://mypy.readthedocs.io/en/latest/common_issues.html#variables-vs-type-aliases +# Variable "stem.control.EventType" is not valid as a type [valid-type] + +mypy.ignore * => "Enum" has no attribute * +mypy.ignore * => "_IntegerEnum" has no attribute * +mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html* +mypy.ignore * => *is not valid as a type* + +# Metaprogramming prevents mypy from determining descriptor attributes. + +mypy.ignore * => "Descriptor" has no attribute "* + # Test modules we want to run. Modules are roughly ordered by the dependencies # so the lowest level tests come first. This is because a problem in say, # controller message parsing, will cause all higher level tests to fail too. diff --git a/test/task.py b/test/task.py index 939e263c7..b2957e657 100644 --- a/test/task.py +++ b/test/task.py @@ -16,12 +16,14 @@ |- CRYPTO_VERSION - checks our version of cryptography |- PYFLAKES_VERSION - checks our version of pyflakes |- PYCODESTYLE_VERSION - checks our version of pycodestyle + |- MYPY_VERSION - checks our version of mypy |- CLEAN_PYC - removes any *.pyc without a corresponding *.py |- REMOVE_TOR_DATA_DIR - removes our tor data directory |- IMPORT_TESTS - ensure all test modules have been imported |- UNUSED_TESTS - checks to see if any tests are missing from our settings |- PYFLAKES_TASK - static checks - +- PYCODESTYLE_TASK - style checks + |- PYCODESTYLE_TASK - style checks + +- MYPY_TASK - type checks """ import importlib @@ -60,12 +62,12 @@ 'cache_fallback_directories.py', 'setup.py', 'tor-prompt', - os.path.join('docs', 'republish.py'), os.path.join('docs', 'roles.py'), )] PYFLAKES_UNAVAILABLE = 'Static error checking requires pyflakes version 0.7.3 or later. Please install it from ...\n https://pypi.org/project/pyflakes/\n' PYCODESTYLE_UNAVAILABLE = 'Style checks require pycodestyle version 1.4.2 or later. 
Please install it from...\n https://pypi.org/project/pycodestyle/\n' +MYPY_UNAVAILABLE = 'Type checks require mypy. Please install it from...\n http://mypy-lang.org/\n' def _check_stem_version(): @@ -324,6 +326,7 @@ def run(self): CRYPTO_VERSION = ModuleVersion('cryptography version', 'cryptography', lambda: test.require.CRYPTOGRAPHY_AVAILABLE) PYFLAKES_VERSION = ModuleVersion('pyflakes version', 'pyflakes') PYCODESTYLE_VERSION = ModuleVersion('pycodestyle version', ['pycodestyle', 'pep8']) +MYPY_VERSION = ModuleVersion('mypy version', 'mypy.version') CLEAN_PYC = Task('checking for orphaned .pyc files', _clean_orphaned_pyc, (SRC_PATHS,), print_runtime = True) REMOVE_TOR_DATA_DIR = Task('emptying our tor data directory', _remove_tor_data_dir) IMPORT_TESTS = Task('importing test modules', _import_tests, print_runtime = True) @@ -348,3 +351,11 @@ def run(self): is_available = stem.util.test_tools.is_pycodestyle_available(), unavailable_msg = PYCODESTYLE_UNAVAILABLE, ) + +MYPY_TASK = StaticCheckTask( + 'running mypy', + stem.util.test_tools.type_issues, + args = (['--config-file', os.path.join(test.STEM_BASE, 'test', 'mypy.ini'), os.path.join(test.STEM_BASE, 'stem')],), + is_available = stem.util.test_tools.is_mypy_available(), + unavailable_msg = MYPY_UNAVAILABLE, +) diff --git a/test/unit/client/address.py b/test/unit/client/address.py index c4e51f4eb..352a8d8c1 100644 --- a/test/unit/client/address.py +++ b/test/unit/client/address.py @@ -50,7 +50,7 @@ def test_unknown_type(self): self.assertEqual(AddrType.UNKNOWN, addr.type) self.assertEqual(12, addr.type_int) self.assertEqual(None, addr.value) - self.assertEqual('hello', addr.value_bin) + self.assertEqual(b'hello', addr.value_bin) def test_packing(self): test_data = { diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index 37e252b96..02ed27749 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -206,7 +206,7 @@ def test_get_ports(self, get_conf_mock, 
get_info_mock): get_info_mock.side_effect = InvalidArguments - get_conf_mock.side_effect = lambda param, **kwargs: { + get_conf_mock.side_effect = lambda param, *args, **kwargs: { 'ControlPort': '9050', 'ControlListenAddress': ['127.0.0.1'], }[param] @@ -217,7 +217,7 @@ def test_get_ports(self, get_conf_mock, get_info_mock): # non-local addresss - get_conf_mock.side_effect = lambda param, **kwargs: { + get_conf_mock.side_effect = lambda param, *args, **kwargs: { 'ControlPort': '9050', 'ControlListenAddress': ['27.4.4.1'], }[param] @@ -679,7 +679,7 @@ def test_get_effective_rate(self, get_conf_mock): # check default if nothing was set - get_conf_mock.side_effect = lambda param, **kwargs: { + get_conf_mock.side_effect = lambda param, *args, **kwargs: { 'BandwidthRate': '1073741824', 'BandwidthBurst': '1073741824', 'RelayBandwidthRate': '0', diff --git a/test/unit/descriptor/bandwidth_file.py b/test/unit/descriptor/bandwidth_file.py index 9bee5f95d..5e56f9d24 100644 --- a/test/unit/descriptor/bandwidth_file.py +++ b/test/unit/descriptor/bandwidth_file.py @@ -7,6 +7,7 @@ import unittest import stem.descriptor +import stem.util.str_tools from unittest.mock import Mock, patch @@ -334,5 +335,5 @@ def test_invalid_timestamp(self): ) for value in test_values: - expected_exc = "First line should be a unix timestamp, but was '%s'" % value + expected_exc = "First line should be a unix timestamp, but was '%s'" % stem.util.str_tools._to_unicode(value) self.assertRaisesWith(ValueError, expected_exc, BandwidthFile.create, {'timestamp': value}) diff --git a/test/unit/interpreter/arguments.py b/test/unit/interpreter/arguments.py index df81e7e3b..d61de42df 100644 --- a/test/unit/interpreter/arguments.py +++ b/test/unit/interpreter/arguments.py @@ -1,39 +1,39 @@ import unittest -from stem.interpreter.arguments import DEFAULT_ARGS, parse, get_help +from stem.interpreter.arguments import Arguments class TestArgumentParsing(unittest.TestCase): def test_that_we_get_default_values(self): - 
args = parse([]) + args = Arguments.parse([]) - for attr in DEFAULT_ARGS: - self.assertEqual(DEFAULT_ARGS[attr], getattr(args, attr)) + for attr, value in Arguments._field_defaults.items(): + self.assertEqual(value, getattr(args, attr)) def test_that_we_load_arguments(self): - args = parse(['--interface', '10.0.0.25:80']) + args = Arguments.parse(['--interface', '10.0.0.25:80']) self.assertEqual('10.0.0.25', args.control_address) self.assertEqual(80, args.control_port) - args = parse(['--interface', '80']) - self.assertEqual(DEFAULT_ARGS['control_address'], args.control_address) + args = Arguments.parse(['--interface', '80']) + self.assertEqual('127.0.0.1', args.control_address) self.assertEqual(80, args.control_port) - args = parse(['--socket', '/tmp/my_socket']) + args = Arguments.parse(['--socket', '/tmp/my_socket']) self.assertEqual('/tmp/my_socket', args.control_socket) - args = parse(['--help']) + args = Arguments.parse(['--help']) self.assertEqual(True, args.print_help) def test_examples(self): - args = parse(['-i', '1643']) + args = Arguments.parse(['-i', '1643']) self.assertEqual(1643, args.control_port) - args = parse(['-s', '~/.tor/socket']) + args = Arguments.parse(['-s', '~/.tor/socket']) self.assertEqual('~/.tor/socket', args.control_socket) def test_that_we_reject_unrecognized_arguments(self): - self.assertRaises(ValueError, parse, ['--blarg', 'stuff']) + self.assertRaises(ValueError, Arguments.parse, ['--blarg', 'stuff']) def test_that_we_reject_invalid_interfaces(self): invalid_inputs = ( @@ -49,15 +49,15 @@ def test_that_we_reject_invalid_interfaces(self): ) for invalid_input in invalid_inputs: - self.assertRaises(ValueError, parse, ['--interface', invalid_input]) + self.assertRaises(ValueError, Arguments.parse, ['--interface', invalid_input]) def test_run_with_command(self): - self.assertEqual('GETINFO version', parse(['--run', 'GETINFO version']).run_cmd) + self.assertEqual('GETINFO version', Arguments.parse(['--run', 'GETINFO 
version']).run_cmd) def test_run_with_path(self): - self.assertEqual(__file__, parse(['--run', __file__]).run_path) + self.assertEqual(__file__, Arguments.parse(['--run', __file__]).run_path) def test_get_help(self): - help_text = get_help() + help_text = Arguments.get_help() self.assertTrue('Interactive interpreter for Tor.' in help_text) self.assertTrue('change control interface from 127.0.0.1:default' in help_text) diff --git a/test/unit/response/protocolinfo.py b/test/unit/response/protocolinfo.py index dd8d21607..a71746c99 100644 --- a/test/unit/response/protocolinfo.py +++ b/test/unit/response/protocolinfo.py @@ -13,8 +13,8 @@ from unittest.mock import Mock, patch +from stem.connection import AuthMethod from stem.response import ControlMessage -from stem.response.protocolinfo import AuthMethod NO_AUTH = """250-PROTOCOLINFO 1 250-AUTH METHODS=NULL diff --git a/test/unit/util/proc.py b/test/unit/util/proc.py index 2316f669e..39087cbe0 100644 --- a/test/unit/util/proc.py +++ b/test/unit/util/proc.py @@ -147,18 +147,17 @@ def test_stats(self, get_line_mock): # tests the case where pid = 0 - if 'start time' in args: - response = 10 - else: - response = () - - for arg in args: - if arg == 'command': - response += ('sched',) - elif arg == 'utime': - response += ('0',) - elif arg == 'stime': - response += ('0',) + response = () + + for arg in args: + if arg == 'command': + response += ('sched',) + elif arg == 'utime': + response += ('0',) + elif arg == 'stime': + response += ('0',) + elif arg == 'start time': + response += ('10',) get_line_mock.side_effect = lambda *params: { ('/proc/0/stat', '0', 'process %s' % ', '.join(args)): stat