diff --git a/examples/test_examples.py b/examples/test_examples.py index 4ad9acd91..001f9008c 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -19,16 +19,17 @@ # limitations under the License. -from unittest import TestCase +from test.util import ServerTestCase # tag::minimal-example-import[] from neo4j.v1 import GraphDatabase # end::minimal-example-import[] -class FreshDatabaseTestCase(TestCase): +class FreshDatabaseTestCase(ServerTestCase): def setUp(self): + ServerTestCase.setUp(self) session = GraphDatabase.driver("bolt://localhost").session() session.run("MATCH (n) DETACH DELETE n") session.close() diff --git a/neo4j/v1/__init__.py b/neo4j/v1/__init__.py index d51d7b9af..1a1b454b3 100644 --- a/neo4j/v1/__init__.py +++ b/neo4j/v1/__init__.py @@ -18,5 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .constants import * from .session import * from .typesystem import * diff --git a/neo4j/v1/compat.py b/neo4j/v1/compat.py index 24cdbc744..dc21adad6 100644 --- a/neo4j/v1/compat.py +++ b/neo4j/v1/compat.py @@ -90,19 +90,3 @@ def perf_counter(): from urllib.parse import urlparse except ImportError: from urlparse import urlparse - - -try: - from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, HAS_SNI -except ImportError: - from ssl import wrap_socket, PROTOCOL_SSLv23 - - def secure_socket(s, host): - return wrap_socket(s, ssl_version=PROTOCOL_SSLv23) - -else: - - def secure_socket(s, host): - ssl_context = SSLContext(PROTOCOL_SSLv23) - ssl_context.options |= OP_NO_SSLv2 - return ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index f64aff0c4..165b4c30d 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -21,25 +21,24 @@ from __future__ import division +from base64 import b64encode from collections import deque from io import BytesIO import logging -from os import environ +from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY +from os.path import dirname, isfile from select import select from socket import create_connection, SHUT_RDWR +from ssl import HAS_SNI, SSLError from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from -from ..meta import version -from .compat import hex2, secure_socket +from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \ + SECURITY_DEFAULT, SECURITY_TRUST_ON_FIRST_USE +from .compat import hex2 from .exceptions import ProtocolError from .packstream import Packer, Unpacker -DEFAULT_PORT = 7687 -DEFAULT_USER_AGENT = "neo4j-python/%s" % version - -MAGIC_PREAMBLE = 0x6060B017 - # Signature bytes for each message type INIT = b"\x01" # 0000 0001 // INIT RESET = b"\x0F" # 0000 1111 // RESET @@ -211,6 +210,10 @@ def __init__(self, sock, **config): user_agent = config.get("user_agent", DEFAULT_USER_AGENT) if isinstance(user_agent, bytes): user_agent = user_agent.decode("UTF-8") + self.user_agent = user_agent + + # Pick up the server certificate, if any + self.der_encoded_server_certificate = config.get("der_encoded_server_certificate") def on_failure(metadata): raise ProtocolError("Initialisation failed") @@ -218,7 +221,7 @@ def on_failure(metadata): response = Response(self) response.on_failure = on_failure - self.append(INIT, (user_agent,), response=response) + self.append(INIT, (self.user_agent,), response=response) self.send() while not response.complete: self.fetch() @@ -313,7 +316,53 @@ def close(self): self.closed = True -def connect(host, port=None, **config): +class CertificateStore(object): + + def match_or_trust(self, host, der_encoded_certificate): + """ Check whether the supplied certificate matches that stored for the + specified host. If it does, return ``True``, if it doesn't, return + ``False``. If no entry for that host is found, add it to the store + and return ``True``. + + :arg host: + :arg der_encoded_certificate: + :return: + """ + raise NotImplementedError() + + +class PersonalCertificateStore(CertificateStore): + + def __init__(self, path=None): + self.path = path or KNOWN_HOSTS + + def match_or_trust(self, host, der_encoded_certificate): + base64_encoded_certificate = b64encode(der_encoded_certificate) + if isfile(self.path): + with open(self.path) as f_in: + for line in f_in: + known_host, _, known_cert = line.strip().partition(":") + known_cert = known_cert.encode("utf-8") + if host == known_host: + return base64_encoded_certificate == known_cert + # First use (no hosts match) + try: + makedirs(dirname(self.path)) + except OSError: + pass + f_out = os_open(self.path, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows + if isinstance(host, bytes): + os_write(f_out, host) + else: + os_write(f_out, host.encode("utf-8")) + os_write(f_out, b":") + os_write(f_out, base64_encoded_certificate) + os_write(f_out, b"\n") + os_close(f_out) + return True + + +def connect(host, port=None, ssl_context=None, **config): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -323,14 +372,28 @@ def connect(host, port=None, **config): if __debug__: log_info("~~ [CONNECT] %s %d", host, port) s = create_connection((host, port)) - # Secure the connection if so requested - try: - secure = environ["NEO4J_SECURE"] - except KeyError: - secure = config.get("secure", False) - if secure: + # Secure the connection if an SSL context has been provided + if ssl_context: if __debug__: log_info("~~ [SECURE] %s", host) - s = secure_socket(s, host) + try: + s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) + except SSLError as cause: + error = ProtocolError("Cannot establish secure connection; %s" % cause.args[1]) + error.__cause__ = cause + raise error + else: + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise ProtocolError("When using a secure socket, the server should always provide a certificate") + security = config.get("security", SECURITY_DEFAULT) + if security == SECURITY_TRUST_ON_FIRST_USE: + store = PersonalCertificateStore() + if not store.match_or_trust(host, der_encoded_server_certificate): + raise ProtocolError("Server certificate does not match known certificate for %r; check " + "details in file %r" % (host, KNOWN_HOSTS)) + else: + der_encoded_server_certificate = None # Send details of the protocol versions supported supported_versions = [1, 0, 0, 0] @@ -364,4 +427,4 @@ def connect(host, port=None, **config): s.shutdown(SHUT_RDWR) s.close() else: - return Connection(s, **config) + return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config) diff --git a/neo4j/v1/constants.py b/neo4j/v1/constants.py new file mode 100644 index 000000000..238c24ed4 --- /dev/null +++ b/neo4j/v1/constants.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from os.path import expanduser, join + +from ..meta import version + + +DEFAULT_PORT = 7687 +DEFAULT_USER_AGENT = "neo4j-python/%s" % version + +KNOWN_HOSTS = join(expanduser("~"), ".neo4j", "known_hosts") + +MAGIC_PREAMBLE = 0x6060B017 + +SECURITY_NONE = 0 +SECURITY_TRUST_ON_FIRST_USE = 1 +SECURITY_VERIFIED = 2 + +SECURITY_DEFAULT = SECURITY_TRUST_ON_FIRST_USE diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index cace94ff6..56b0ac5bb 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -29,9 +29,11 @@ class which can be used to obtain `Driver` instances that are used for from __future__ import division from collections import deque, namedtuple +from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED, Purpose from .compat import integer, string, urlparse from .connection import connect, Response, RUN, PULL_ALL +from .constants import SECURITY_NONE, SECURITY_VERIFIED, SECURITY_DEFAULT from .exceptions import CypherError, ResultError from .typesystem import hydrated @@ -77,6 +79,16 @@ def __init__(self, url, **config): self.config = config self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) self.session_pool = deque() + self.security = security = config.get("security", SECURITY_DEFAULT) + if security > SECURITY_NONE: + ssl_context = SSLContext(PROTOCOL_SSLv23) + ssl_context.options |= OP_NO_SSLv2 + if security >= SECURITY_VERIFIED: + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.load_default_certs(Purpose.SERVER_AUTH) + self.ssl_context = ssl_context + else: + self.ssl_context = None def session(self): """ Create a new session based on the graph database details @@ -425,7 +437,7 @@ class Session(object): def __init__(self, driver): self.driver = driver - self.connection = connect(driver.host, driver.port, **driver.config) + self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config) self.transaction = None self.last_cursor = None @@ -654,6 +666,7 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def record(obj): """ Obtain an immutable record for the given object (either by calling obj.__record__() or by copying out the record data) diff --git a/test/tck/tck_util.py b/test/tck/tck_util.py index d05913f75..cb755fbe1 100644 --- a/test/tck/tck_util.py +++ b/test/tck/tck_util.py @@ -18,11 +18,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j.v1 import compat, Relationship, Node, Path +from neo4j.v1 import GraphDatabase, Relationship, Node, Path, SECURITY_NONE +from neo4j.v1.compat import string -from neo4j.v1 import GraphDatabase -driver = GraphDatabase.driver("bolt://localhost") +driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE) def send_string(text): @@ -39,11 +39,10 @@ def send_parameters(statement, parameters): return list(cursor.stream()) -def to_unicode(val): - try: - return unicode(val) - except NameError: - return str(val) +try: + to_unicode = unicode +except NameError: + to_unicode = str def string_to_type(str): @@ -91,7 +90,7 @@ def __init__(self, entity): elif isinstance(entity, Path): self.content = self.create_path(entity) elif isinstance(entity, int) or isinstance(entity, float) or isinstance(entity, - (str, compat.string)) or entity is None: + (str, string)) or entity is None: self.content['value'] = entity else: raise ValueError("Do not support object type: %s" % entity) diff --git a/test/test_session.py b/test/test_session.py index 4cd8b118e..8ee09065d 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -19,15 +19,19 @@ # limitations under the License. -from unittest import TestCase +from socket import socket +from ssl import SSLSocket from mock import patch -from neo4j.v1.exceptions import ResultError -from neo4j.v1.session import GraphDatabase, CypherError, Record, record +from neo4j.v1.constants import SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE +from neo4j.v1.exceptions import CypherError, ResultError +from neo4j.v1.session import GraphDatabase, Record, record from neo4j.v1.typesystem import Node, Relationship, Path +from test.util import ServerTestCase -class DriverTestCase(TestCase): + +class DriverTestCase(ServerTestCase): def test_healthy_session_will_be_returned_to_the_pool_on_close(self): driver = GraphDatabase.driver("bolt://localhost") @@ -60,9 +64,6 @@ def test_session_that_dies_in_the_pool_will_not_be_given_out(self): session_2 = driver.session() assert session_2 is not session_1 - -class RunTestCase(TestCase): - def test_must_use_valid_url_scheme(self): with self.assertRaises(ValueError): GraphDatabase.driver("x://xxx") @@ -83,6 +84,52 @@ def test_sessions_are_not_reused_if_still_in_use(self): session_1.close() assert session_1 is not session_2 + +class SecurityTestCase(ServerTestCase): + + def test_default_session_uses_tofu(self): + driver = GraphDatabase.driver("bolt://localhost") + assert driver.security == SECURITY_TRUST_ON_FIRST_USE + + def test_insecure_session_uses_normal_socket(self): + driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE) + session = driver.session() + connection = session.connection + assert isinstance(connection.channel.socket, socket) + assert connection.der_encoded_server_certificate is None + session.close() + + def test_tofu_session_uses_secure_socket(self): + driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE) + session = driver.session() + connection = session.connection + assert isinstance(connection.channel.socket, SSLSocket) + assert connection.der_encoded_server_certificate is not None + session.close() + + def test_tofu_session_trusts_certificate_after_first_use(self): + driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE) + session = driver.session() + connection = session.connection + certificate = connection.der_encoded_server_certificate + session.close() + session = driver.session() + connection = session.connection + assert connection.der_encoded_server_certificate == certificate + session.close() + + # TODO: Find a way to run this test + # def test_verified_session_uses_secure_socket(self): + # driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_VERIFIED) + # session = driver.session() + # connection = session.connection + # assert isinstance(connection.channel.socket, SSLSocket) + # assert connection.der_encoded_server_certificate is not None + # session.close() + + +class RunTestCase(ServerTestCase): + def test_can_run_simple_statement(self): session = GraphDatabase.driver("bolt://localhost").session() count = 0 @@ -209,7 +256,7 @@ def test_keys_with_an_error(self): _ = list(cursor.keys()) -class SummaryTestCase(TestCase): +class SummaryTestCase(ServerTestCase): def test_can_obtain_summary_after_consuming_result(self): with GraphDatabase.driver("bolt://localhost").session() as session: @@ -303,7 +350,7 @@ def test_can_obtain_notification_info(self): assert position.column == 1 -class ResetTestCase(TestCase): +class ResetTestCase(ServerTestCase): def test_automatic_reset_after_failure(self): with GraphDatabase.driver("bolt://localhost").session() as session: @@ -327,7 +374,7 @@ def test_defunct(self): assert session.connection.closed -class RecordTestCase(TestCase): +class RecordTestCase(ServerTestCase): def test_record_equality(self): record1 = Record(["name", "empire"], ["Nigel", "The British Empire"]) record2 = Record(["name", "empire"], ["Nigel", "The British Empire"]) @@ -401,7 +448,8 @@ def test_record_repr(self): assert repr(a_record) == "" -class TransactionTestCase(TestCase): +class TransactionTestCase(ServerTestCase): + def test_can_commit_transaction(self): with GraphDatabase.driver("bolt://localhost").session() as session: tx = session.begin_transaction() diff --git a/test/util.py b/test/util.py index 793fadb67..9148b5898 100644 --- a/test/util.py +++ b/test/util.py @@ -20,8 +20,15 @@ import functools +from os import rename +from os.path import isfile +from unittest import TestCase from neo4j.util import Watcher +from neo4j.v1.constants import KNOWN_HOSTS + + +KNOWN_HOSTS_BACKUP = KNOWN_HOSTS + ".backup" def watch(f): @@ -39,3 +46,19 @@ def wrapper(*args, **kwargs): f(*args, **kwargs) watcher.stop() return wrapper + + +class ServerTestCase(TestCase): + """ Base class for test cases that use a remote server. + """ + + known_hosts = KNOWN_HOSTS + known_hosts_backup = known_hosts + ".backup" + + def setUp(self): + if isfile(self.known_hosts): + rename(self.known_hosts, self.known_hosts_backup) + + def tearDown(self): + if isfile(self.known_hosts_backup): + rename(self.known_hosts_backup, self.known_hosts)