Skip to content

Commit

Permalink
Merge pull request #204 from poissoncorp/RDBC-788
Browse files Browse the repository at this point in the history
RDBC-788 Secured Changes API
  • Loading branch information
ml054 authored Feb 8, 2024
2 parents 4990617 + def555a commit 16cf815
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 95 deletions.
200 changes: 122 additions & 78 deletions ravendb/changes/database_changes.py

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions ravendb/changes/observers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Lock
from typing import Callable, Generic, TypeVar
from typing import Callable, Generic, TypeVar, Optional

from ravendb.tools.concurrentset import ConcurrentSet

_T_Change = TypeVar("_T_Change")


class Observable(Generic[_T_Change]):
def __init__(self, on_connect=None, on_disconnect=None, executor=None):
def __init__(
self,
on_connect: Optional[Callable[[], None]] = None,
on_disconnect: Optional[Callable[[], None]] = None,
executor: Optional[ThreadPoolExecutor] = None,
):
self.on_connect = on_connect
self._on_disconnect = on_disconnect
self.last_exception = None
Expand Down Expand Up @@ -75,7 +80,7 @@ def dec(self):
if self._value == 0:
self.set(self._executor.submit(self._on_disconnect))

def set(self, future):
def set(self, future: Future) -> None:
if not self._future_set.done():

def done_callback(f):
Expand All @@ -92,7 +97,7 @@ def done_callback(f):
future.add_done_callback(done_callback)
self._future = future

def error(self, exception):
def error(self, exception: Exception):
future = Future()
self.set(future)
future.set_exception(exception)
Expand Down
6 changes: 3 additions & 3 deletions ravendb/documents/store/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def __init__(self, urls: Union[str, List[str]] = None, database: Optional[str] =
self.__multi_db_hilo: Optional[MultiDatabaseHiLoGenerator] = None
self.__identifier: Optional[str] = None
self.__add_change_lock = threading.Lock()
self.__database_changes = {}
self.__database_changes: Dict[str, DatabaseChanges] = {}
self.__after_close: List[Callable[[], None]] = []
self.__before_close: List[Callable[[], None]] = []
self.__time_series_operation: Optional[TimeSeriesOperations] = None
Expand Down Expand Up @@ -493,8 +493,8 @@ def changes(self, database=None, on_error=None, executor=None) -> DatabaseChange
)
return self.__database_changes[database]

def __on_close_change(self, database):
self.__database_changes.pop(database, None)
def __on_close_change(self, database_name: str):
self.__database_changes.pop(database_name, None)

def set_request_timeout(self, timeout: datetime.timedelta, database: Optional[str] = None) -> Callable[[], None]:
self.assert_initialized()
Expand Down
3 changes: 2 additions & 1 deletion ravendb/serverwide/operations/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def __init__(

def create_request(self, node: ServerNode) -> requests.Request:
url = (
f"{node.url}/admin/databases?name={self.__database_name}&replicationFactor={self.__replication_factor}"
f"{node.url}/admin/databases?name={self.__database_name}"
f"&replicationFactor={self.__replication_factor}&?raft-request-id={self.get_raft_unique_request_id}"
)

request = requests.Request("PUT")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_can_stop_and_start(self):

self.assertEqual(IndexRunningStatus.RUNNING, status.status)

self.assertEquals(1, len(status.indexes))
self.assertEqual(1, len(status.indexes))

self.assertEqual(IndexRunningStatus.RUNNING, status.indexes[0].status)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_can_delete_index_errors(self):
index_errors3 = self.store.maintenance.send(GetIndexErrorsOperation("Index3"))

self.assertGreater(sum([len(x.errors) for x in index_errors1]), 0)
self.assertEquals(sum([len(x.errors) for x in index_errors2]), 0)
self.assertEqual(sum([len(x.errors) for x in index_errors2]), 0)
self.assertGreater(sum([len(x.errors) for x in index_errors3]), 0)

self.store.maintenance.send(DeleteIndexErrorsOperation())
Expand All @@ -87,8 +87,8 @@ def test_can_delete_index_errors(self):
index_errors2 = self.store.maintenance.send(GetIndexErrorsOperation("Index2"))
index_errors3 = self.store.maintenance.send(GetIndexErrorsOperation("Index3"))

self.assertEquals(sum([len(x.errors) for x in index_errors1]), 0)
self.assertEquals(sum([len(x.errors) for x in index_errors2]), 0)
self.assertEquals(sum([len(x.errors) for x in index_errors3]), 0)
self.assertEqual(sum([len(x.errors) for x in index_errors1]), 0)
self.assertEqual(sum([len(x.errors) for x in index_errors2]), 0)
self.assertEqual(sum([len(x.errors) for x in index_errors3]), 0)

RavenTestHelper.assert_no_index_errors(self.store)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from ravendb import AbstractIndexCreationTask, SetIndexesPriorityOperation
from ravendb.changes.observers import ActionObserver
from ravendb.changes.types import DocumentChange, IndexChange
from ravendb.changes.types import DocumentChange, IndexChange, DocumentChangeType
from ravendb.documents.indexes.definitions import IndexPriority
from ravendb.infrastructure.entities import User
from ravendb.infrastructure.orders import Order
Expand Down Expand Up @@ -235,3 +235,62 @@ def __ev(value: DocumentChange):
self.assertEqual("users/2", document_changes[1].key)

close_action()

def test_changes_with_https(self):
event = Event()
changes_list = []
exception = None

def _on_error(e):
nonlocal exception
exception = e

changes = self.secured_document_store.changes(on_error=_on_error)
observable = changes.for_document("users/1")

def __ev(value: DocumentChange):
changes_list.append(value)
event.set()

observer = ActionObserver(__ev)
close_action = observable.subscribe_with_observer(observer)
try:
observable.ensure_subscribe_now()
except Exception:
raise exception

with self.secured_document_store.open_session() as session:
user = User()
session.store(user, "users/1")
session.save_changes()

event.wait(2)
document_change = changes_list[0]
self.assertIsNotNone(document_change)
self.assertEqual("users/1", document_change.key)
self.assertEqual(DocumentChangeType.PUT, document_change.type_of_change)

changes_list.clear()

try:
event.wait(1)
except Exception:
pass

self.assertEqual(0, len(changes_list))
close_action()
# at this point we should be unsubscribed from changes on 'users/1'

with self.secured_document_store.open_session() as session:
user = User()
user.name = "another name"
session.store(user, "users/1")
session.save_changes()

# it should be empty
try:
event.wait(1)
except Exception:
pass

self.assertEqual(0, len(changes_list))
4 changes: 3 additions & 1 deletion ravendb/tools/parsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from decimal import InvalidOperation
from typing import Any, Dict, Optional

from ijson.common import integer_or_decimal, IncompleteJSONError
from ijson.backends.python import UnexpectedSymbol
from _elementtree import ParseError
Expand Down Expand Up @@ -112,7 +114,7 @@ def create_object(self, gen):

raise ParseError("End object expected, but the generator ended before we got it")

def next_object(self):
def next_object(self) -> Optional[Dict[str, Any]]:
try:
(_, text) = next(self.lexer)
if IS_WEBSOCKET and text == ",":
Expand Down
4 changes: 3 additions & 1 deletion ravendb/util/tcp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def connect(
) -> socket.socket:
hostname, port = url_string.replace("tcp://", "").split(":")
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

is_ssl_socket = server_certificate_base64 and client_certificate_pem_path
if server_certificate_base64 and client_certificate_pem_path:
if is_ssl_socket:
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
context.load_cert_chain(client_certificate_pem_path, password=certificate_private_key_password)
s = context.wrap_socket(s)

s.connect((hostname, int(port)))
if is_ssl_socket and base64.b64decode(server_certificate_base64) != s.getpeercert(True):
raise ConnectionError("Failed to validate public server certificate.")
Expand Down

0 comments on commit 16cf815

Please sign in to comment.