Skip to content

Commit

Permalink
Merge pull request #36 from neo4j/1.0-tls
Browse files Browse the repository at this point in the history
1.0 tls
  • Loading branch information
pontusmelke committed Feb 25, 2016
2 parents b070622 + e9269db commit 0a704bf
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 57 deletions.
5 changes: 3 additions & 2 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
# limitations under the License.


from unittest import TestCase
from test.util import ServerTestCase

# tag::minimal-example-import[]
from neo4j.v1 import GraphDatabase
# end::minimal-example-import[]


class FreshDatabaseTestCase(TestCase):
class FreshDatabaseTestCase(ServerTestCase):

def setUp(self):
ServerTestCase.setUp(self)
session = GraphDatabase.driver("bolt://localhost").session()
session.run("MATCH (n) DETACH DELETE n")
session.close()
Expand Down
1 change: 1 addition & 0 deletions neo4j/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .constants import *
from .session import *
from .typesystem import *
16 changes: 0 additions & 16 deletions neo4j/v1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,3 @@ def perf_counter():
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse


try:
from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, HAS_SNI
except ImportError:
from ssl import wrap_socket, PROTOCOL_SSLv23

def secure_socket(s, host):
return wrap_socket(s, ssl_version=PROTOCOL_SSLv23)

else:

def secure_socket(s, host):
ssl_context = SSLContext(PROTOCOL_SSLv23)
ssl_context.options |= OP_NO_SSLv2
return ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)
99 changes: 81 additions & 18 deletions neo4j/v1/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,24 @@

from __future__ import division

from base64 import b64encode
from collections import deque
from io import BytesIO
import logging
from os import environ
from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY
from os.path import dirname, isfile
from select import select
from socket import create_connection, SHUT_RDWR
from ssl import HAS_SNI, SSLError
from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from

from ..meta import version
from .compat import hex2, secure_socket
from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \
SECURITY_DEFAULT, SECURITY_TRUST_ON_FIRST_USE
from .compat import hex2
from .exceptions import ProtocolError
from .packstream import Packer, Unpacker


DEFAULT_PORT = 7687
DEFAULT_USER_AGENT = "neo4j-python/%s" % version

MAGIC_PREAMBLE = 0x6060B017

# Signature bytes for each message type
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
RESET = b"\x0F" # 0000 1111 // RESET
Expand Down Expand Up @@ -211,14 +210,18 @@ def __init__(self, sock, **config):
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
if isinstance(user_agent, bytes):
user_agent = user_agent.decode("UTF-8")
self.user_agent = user_agent

# Pick up the server certificate, if any
self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")

def on_failure(metadata):
raise ProtocolError("Initialisation failed")

response = Response(self)
response.on_failure = on_failure

self.append(INIT, (user_agent,), response=response)
self.append(INIT, (self.user_agent,), response=response)
self.send()
while not response.complete:
self.fetch()
Expand Down Expand Up @@ -313,7 +316,53 @@ def close(self):
self.closed = True


def connect(host, port=None, **config):
class CertificateStore(object):

def match_or_trust(self, host, der_encoded_certificate):
""" Check whether the supplied certificate matches that stored for the
specified host. If it does, return ``True``, if it doesn't, return
``False``. If no entry for that host is found, add it to the store
and return ``True``.
:arg host:
:arg der_encoded_certificate:
:return:
"""
raise NotImplementedError()


class PersonalCertificateStore(CertificateStore):

def __init__(self, path=None):
self.path = path or KNOWN_HOSTS

def match_or_trust(self, host, der_encoded_certificate):
base64_encoded_certificate = b64encode(der_encoded_certificate)
if isfile(self.path):
with open(self.path) as f_in:
for line in f_in:
known_host, _, known_cert = line.strip().partition(":")
known_cert = known_cert.encode("utf-8")
if host == known_host:
return base64_encoded_certificate == known_cert
# First use (no hosts match)
try:
makedirs(dirname(self.path))
except OSError:
pass
f_out = os_open(self.path, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows
if isinstance(host, bytes):
os_write(f_out, host)
else:
os_write(f_out, host.encode("utf-8"))
os_write(f_out, b":")
os_write(f_out, base64_encoded_certificate)
os_write(f_out, b"\n")
os_close(f_out)
return True


def connect(host, port=None, ssl_context=None, **config):
""" Connect and perform a handshake and return a valid Connection object, assuming
a protocol version can be agreed.
"""
Expand All @@ -323,14 +372,28 @@ def connect(host, port=None, **config):
if __debug__: log_info("~~ [CONNECT] %s %d", host, port)
s = create_connection((host, port))

# Secure the connection if so requested
try:
secure = environ["NEO4J_SECURE"]
except KeyError:
secure = config.get("secure", False)
if secure:
# Secure the connection if an SSL context has been provided
if ssl_context:
if __debug__: log_info("~~ [SECURE] %s", host)
s = secure_socket(s, host)
try:
s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)
except SSLError as cause:
error = ProtocolError("Cannot establish secure connection; %s" % cause.args[1])
error.__cause__ = cause
raise error
else:
# Check that the server provides a certificate
der_encoded_server_certificate = s.getpeercert(binary_form=True)
if der_encoded_server_certificate is None:
raise ProtocolError("When using a secure socket, the server should always provide a certificate")
security = config.get("security", SECURITY_DEFAULT)
if security == SECURITY_TRUST_ON_FIRST_USE:
store = PersonalCertificateStore()
if not store.match_or_trust(host, der_encoded_server_certificate):
raise ProtocolError("Server certificate does not match known certificate for %r; check "
"details in file %r" % (host, KNOWN_HOSTS))
else:
der_encoded_server_certificate = None

# Send details of the protocol versions supported
supported_versions = [1, 0, 0, 0]
Expand Down Expand Up @@ -364,4 +427,4 @@ def connect(host, port=None, **config):
s.shutdown(SHUT_RDWR)
s.close()
else:
return Connection(s, **config)
return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config)
38 changes: 38 additions & 0 deletions neo4j/v1/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) 2002-2016 "Neo Technology,"
# Network Engine for Objects in Lund AB [http://neotechnology.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from os.path import expanduser, join

from ..meta import version


DEFAULT_PORT = 7687
DEFAULT_USER_AGENT = "neo4j-python/%s" % version

KNOWN_HOSTS = join(expanduser("~"), ".neo4j", "known_hosts")

MAGIC_PREAMBLE = 0x6060B017

SECURITY_NONE = 0
SECURITY_TRUST_ON_FIRST_USE = 1
SECURITY_VERIFIED = 2

SECURITY_DEFAULT = SECURITY_TRUST_ON_FIRST_USE
15 changes: 14 additions & 1 deletion neo4j/v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ class which can be used to obtain `Driver` instances that are used for
from __future__ import division

from collections import deque, namedtuple
from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED, Purpose

from .compat import integer, string, urlparse
from .connection import connect, Response, RUN, PULL_ALL
from .constants import SECURITY_NONE, SECURITY_VERIFIED, SECURITY_DEFAULT
from .exceptions import CypherError, ResultError
from .typesystem import hydrated

Expand Down Expand Up @@ -77,6 +79,16 @@ def __init__(self, url, **config):
self.config = config
self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE)
self.session_pool = deque()
self.security = security = config.get("security", SECURITY_DEFAULT)
if security > SECURITY_NONE:
ssl_context = SSLContext(PROTOCOL_SSLv23)
ssl_context.options |= OP_NO_SSLv2
if security >= SECURITY_VERIFIED:
ssl_context.verify_mode = CERT_REQUIRED
ssl_context.load_default_certs(Purpose.SERVER_AUTH)
self.ssl_context = ssl_context
else:
self.ssl_context = None

def session(self):
""" Create a new session based on the graph database details
Expand Down Expand Up @@ -425,7 +437,7 @@ class Session(object):

def __init__(self, driver):
self.driver = driver
self.connection = connect(driver.host, driver.port, **driver.config)
self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config)
self.transaction = None
self.last_cursor = None

Expand Down Expand Up @@ -654,6 +666,7 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)


def record(obj):
""" Obtain an immutable record for the given object
(either by calling obj.__record__() or by copying out the record data)
Expand Down
17 changes: 8 additions & 9 deletions test/tck/tck_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from neo4j.v1 import compat, Relationship, Node, Path
from neo4j.v1 import GraphDatabase, Relationship, Node, Path, SECURITY_NONE
from neo4j.v1.compat import string

from neo4j.v1 import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost")
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE)


def send_string(text):
Expand All @@ -39,11 +39,10 @@ def send_parameters(statement, parameters):
return list(cursor.stream())


def to_unicode(val):
try:
return unicode(val)
except NameError:
return str(val)
try:
to_unicode = unicode
except NameError:
to_unicode = str


def string_to_type(str):
Expand Down Expand Up @@ -91,7 +90,7 @@ def __init__(self, entity):
elif isinstance(entity, Path):
self.content = self.create_path(entity)
elif isinstance(entity, int) or isinstance(entity, float) or isinstance(entity,
(str, compat.string)) or entity is None:
(str, string)) or entity is None:
self.content['value'] = entity
else:
raise ValueError("Do not support object type: %s" % entity)
Expand Down
Loading

0 comments on commit 0a704bf

Please sign in to comment.