diff --git a/ravendb/changes/database_changes.py b/ravendb/changes/database_changes.py index 06442fd4..e6610fe2 100644 --- a/ravendb/changes/database_changes.py +++ b/ravendb/changes/database_changes.py @@ -1,5 +1,9 @@ +import base64 +import ssl from threading import Lock -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Optional, Callable, Any, List + +from websocket import WebSocket from ravendb.changes.observers import Observable from ravendb.changes.types import ( @@ -9,7 +13,9 @@ CounterChange, OperationStatusChange, TopologyChange, + DatabaseChange, ) +from ravendb.serverwide.commands import GetTcpInfoCommand from ravendb.tools.parsers import IncrementalJsonParser import websocket from ravendb.exceptions.exceptions import NotSupportedException @@ -26,7 +32,14 @@ class DatabaseChanges: - def __init__(self, request_executor: "RequestExecutor", database_name, on_close, on_error=None, executor=None): + def __init__( + self, + request_executor: "RequestExecutor", + database_name: str, + on_close: Callable[[str], None], + on_error: Optional[Callable[[Exception], None]] = None, + executor: Optional[ThreadPoolExecutor] = None, + ): self._request_executor = request_executor self._conventions = request_executor.conventions self._database_name = database_name @@ -36,13 +49,13 @@ def __init__(self, request_executor: "RequestExecutor", database_name, on_close, self._closed = False self._on_close = on_close self.on_error = on_error - self._observables = dict() + self._observables_by_group: Dict[str, Dict[str, Observable[DatabaseChange]]] = {} self._executor = executor if executor else ThreadPoolExecutor(max_workers=10) self._worker = self._executor.submit(self.do_work) self.send_lock = Lock() self._confirmations_lock = Lock() - self._confirmations = {} + self._confirmations: Dict[int, Future] = {} self._command_id = 0 self._immediate_connection = 0 @@ -54,8 +67,42 @@ def __init__(self, request_executor: "RequestExecutor", database_name, on_close, self._logger.addHandler(handler) self._logger.setLevel(logging.DEBUG) + def _ensure_websocket_connected(self, url: str) -> None: + if self._request_executor.certificate_path: + self._connect_websocket_secured(url) + else: + self.client_websocket.connect(url) + + for observables_by_name in self._observables_by_group.values(): + for observer in observables_by_name.values(): + observer.set(self._executor.submit(observer.on_connect)) + self._immediate_connection = 1 + + def _get_server_certificate(self) -> Optional[str]: + cmd = GetTcpInfoCommand(self._request_executor.url) + self._request_executor.execute_command(cmd) + return cmd.result.certificate + + def _connect_websocket_secured(self, url: str) -> None: + # Get server certificate via HTTPS and prepare SSL context + server_certificate = base64.b64decode(self._get_server_certificate()) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + ssl_context.load_cert_chain(self._request_executor.certificate_path) + if self._request_executor.trust_store_path: + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.load_verify_locations(self._request_executor.trust_store_path) + + # Create SSL WebSocket and connect it + self.client_websocket = WebSocket(sslopt={"context": ssl_context}) + self.client_websocket.connect(url, suppress_origin=True) + + # Server certificate authentication + server_certificate_from_tls = self.client_websocket.sock.getpeercert(True) + if server_certificate != server_certificate_from_tls: + raise ValueError("Certificates don't match") + def do_work(self): - preferred_node = self._request_executor.preferred_node.current_node # todo: refactor, protected access + preferred_node = self._request_executor.preferred_node.current_node url = ( f"{preferred_node.url}/databases/{self._database_name}/changes".replace("http://", "ws://") .lower() @@ -66,20 +113,7 @@ def do_work(self): while not self._closed: try: if not self.client_websocket.connected: - # todo: certificates - # if self._request_executor.certificate: - # if isinstance(self._request_executor.certificate, tuple): - # (crt, key) = self._request_executor.certificate - # self.client_websocket.sock_opt.sslopt.update({"certfile": crt, "keyfile": key}) - # else: - # self.client_websocket.sock_opt.sslopt.update( - # {"ca_certs": self._request_executor.certificate} - # ) - self.client_websocket.connect(url) - for observables in self._observables.values(): - for observer in observables.values(): - observer.set(self._executor.submit(observer.on_connect)) - self._immediate_connection = 1 + self._ensure_websocket_connected(url) self.process_changes() except ChangeProcessingException as e: self.notify_about_error(e) @@ -100,39 +134,46 @@ def process_changes(self): while not self._closed: try: response = parser.next_object() - if response: - response_type = response.get("Type", None) - if not response_type: + if not response: + continue + + response_type: Optional[str] = response.get("Type", None) + if not response_type: + continue + + if response_type == "Error": + exception = response["Exception"] + self.notify_about_error(Exception(exception)) + elif response_type == "Confirm": + command_id: Optional[int] = response.get("CommandId", None) + if not command_id or command_id not in self._confirmations: continue - if response_type == "Error": - exception = response["Exception"] - self.notify_about_error(Exception(exception)) - elif response_type == "Confirm": - command_id = response.get("CommandId", None) - if command_id and command_id in self._confirmations: - with self._confirmations_lock: - future = self._confirmations.pop(command_id) - future.set_result("done complete future") - else: - value = response.get("Value", None) - self._notify_subscribers(response_type, value, copy.copy(self._observables[response_type])) + with self._confirmations_lock: + future = self._confirmations.pop(command_id) + future.set_result("done complete future") + else: + change_json_dict: Optional[Dict[str, Any]] = response.get("Value", None) + self._notify_subscribers( + response_type, change_json_dict, copy.copy(self._observables_by_group[response_type]) + ) except Exception as e: self.notify_about_error(e) raise ChangeProcessingException(e) - def _notify_subscribers(self, type_of_change: str, value: Dict, observables: Dict[str, Observable]): + @staticmethod + def _notify_subscribers(type_of_change: str, change_json_dict: Dict[str, Any], observables: Dict[str, Observable]): if type_of_change == "DocumentChange": - result = DocumentChange.from_json(value) + result = DocumentChange.from_json(change_json_dict) elif type_of_change == "IndexChange": - result = IndexChange.from_json(value) + result = IndexChange.from_json(change_json_dict) elif type_of_change == "TimeSeriesChange": - result = TimeSeriesChange.from_json(value) + result = TimeSeriesChange.from_json(change_json_dict) elif type_of_change == "CounterChange": - result = CounterChange.from_json(value) + result = CounterChange.from_json(change_json_dict) elif type_of_change == "OperationStatusChange": - result = OperationStatusChange.from_json(value) + result = OperationStatusChange.from_json(change_json_dict) elif type_of_change == "TopologyChange": - result = TopologyChange.from_json(value) + result = TopologyChange.from_json(change_json_dict) else: raise NotSupportedException(type_of_change) @@ -143,7 +184,7 @@ def close(self): self._closed = True self.client_websocket.close() - for observable in self._observables.values(): + for observable in self._observables_by_group.values(): for observer in observable.values(): observer.close() @@ -151,23 +192,22 @@ def close(self): for confirmation in self._confirmations.values(): confirmation.cancel() - self._observables.clear() + self._observables_by_group.clear() if self._on_close: self._on_close(self._database_name) self._executor.shutdown(wait=True) - def notify_about_error(self, e): + def notify_about_error(self, e: Exception): if self.on_error: self.on_error(e) - for _, observables in self._observables.items(): + for _, observables in self._observables_by_group.items(): for observer in observables.values(): observer.error(e) def for_all_documents(self) -> Observable[DocumentChange]: - # todo: ResourceWarning: unclosed socket - observable = self.get_or_add_observable("DocumentChange", "all-docs", "watch-docs", "unwatch-docs", None)( + observable = self.get_or_add_observable("DocumentChange", "all-docs", "watch-docs", "unwatch-docs")( lambda x: True ) return observable @@ -178,17 +218,16 @@ def for_all_operations(self) -> Observable[OperationStatusChange]: "all-operations", "watch-operations", "unwatch-operations", - None, )(lambda x: True) return observable def for_all_indexes(self) -> Observable[IndexChange]: - observable = self.get_or_add_observable("IndexChange", "all-indexes", "watch-indexes", "unwatch-indexes", None)( + observable = self.get_or_add_observable("IndexChange", "all-indexes", "watch-indexes", "unwatch-indexes")( lambda x: True ) return observable - def for_index(self, index_name) -> Observable[IndexChange]: + def for_index(self, index_name: str) -> Observable[IndexChange]: observable = self.get_or_add_observable( "IndexChange", "indexes/" + index_name, @@ -198,7 +237,7 @@ def for_index(self, index_name) -> Observable[IndexChange]: )(lambda x: x.name.casefold() == index_name.casefold()) return observable - def for_operation_id(self, operation_id) -> Observable[OperationStatusChange]: + def for_operation_id(self, operation_id: int) -> Observable[OperationStatusChange]: observable = self.get_or_add_observable( "OperationsStatusChange", "operations/" + str(operation_id), @@ -208,13 +247,13 @@ def for_operation_id(self, operation_id) -> Observable[OperationStatusChange]: )(lambda x: x.operation_id == str(operation_id)) return observable - def for_document(self, doc_id) -> Observable[DocumentChange]: + def for_document(self, doc_id: str) -> Observable[DocumentChange]: observable = self.get_or_add_observable("DocumentChange", "docs/" + doc_id, "watch-doc", "unwatch-doc", doc_id)( lambda x: x.key.casefold() == doc_id.casefold() ) return observable - def for_documents_start_with(self, doc_id_prefix) -> Observable[DocumentChange]: + def for_documents_start_with(self, doc_id_prefix: str) -> Observable[DocumentChange]: observable = self.get_or_add_observable( "DocumentChange", "prefixes/" + doc_id_prefix, @@ -224,7 +263,7 @@ def for_documents_start_with(self, doc_id_prefix) -> Observable[DocumentChange]: )(lambda x: x.key is not None and x.key.casefold().startswith(doc_id_prefix.casefold())) return observable - def for_documents_in_collection(self, collection_name) -> Observable[DocumentChange]: + def for_documents_in_collection(self, collection_name: str) -> Observable[DocumentChange]: observable = self.get_or_add_observable( "DocumentChange", "collections/" + collection_name, @@ -240,11 +279,10 @@ def for_all_time_series(self) -> Observable[TimeSeriesChange]: "all-timeseries", "watch-all-timeseries", "unwatch-all-timeseries", - None, )(lambda x: True) return observable - def for_time_series(self, time_series_name) -> Observable[TimeSeriesChange]: + def for_time_series(self, time_series_name: str) -> Observable[TimeSeriesChange]: if not time_series_name: raise ValueError("time_series_name cannot be None or empty") observable = self.get_or_add_observable( @@ -256,7 +294,7 @@ def for_time_series(self, time_series_name) -> Observable[TimeSeriesChange]: )(lambda x: x.name.casefold() == time_series_name.casefold()) return observable - def for_time_series_of_document(self, doc_id, time_series_name=None) -> Observable[TimeSeriesChange]: + def for_time_series_of_document(self, doc_id: str, time_series_name: str = None) -> Observable[TimeSeriesChange]: """ Can subscribe to all time series changes that associated with the document or by passing the time series name only for a specific time series @@ -279,18 +317,18 @@ def get_lambda(): name, watch_command, unwatch_command, - value=value, - values=values, + resource_name=value, + resources_names=values, )(get_lambda()) return observable def for_all_counters(self) -> Observable[CounterChange]: - observable = self.get_or_add_observable( - "CounterChange", "all-counters", "watch-counters", "unwatch-counters", None - )(lambda x: True) + observable = self.get_or_add_observable("CounterChange", "all-counters", "watch-counters", "unwatch-counters")( + lambda x: True + ) return observable - def for_counter(self, counter_name) -> Observable[CounterChange]: + def for_counter(self, counter_name: str) -> Observable[CounterChange]: if not counter_name: raise ValueError("counter_name cannot be None or empty") observable = self.get_or_add_observable( @@ -302,7 +340,7 @@ def for_counter(self, counter_name) -> Observable[CounterChange]: )(lambda x: x.name.casefold() == counter_name.casefold()) return observable - def for_counters_of_document(self, doc_id) -> Observable[CounterChange]: + def for_counters_of_document(self, doc_id: str) -> Observable[CounterChange]: """ Can subscribe to all counters changes that associated with the document or """ @@ -314,12 +352,11 @@ def for_counters_of_document(self, doc_id) -> Observable[CounterChange]: f"document/{doc_id}/counter", "watch-document-counters", "unwatch-document-counters", - value=doc_id, - values=None, + resource_name=doc_id, )(lambda x: x.document_id.casefold() == doc_id.casefold()) return observable - def for_counter_of_document(self, doc_id, counter_name) -> Observable[CounterChange]: + def for_counter_of_document(self, doc_id: str, counter_name: str) -> Observable[CounterChange]: """ Can subscribe to all counters changes that associated with the document and for counter name """ @@ -333,38 +370,45 @@ def for_counter_of_document(self, doc_id, counter_name) -> Observable[CounterCha f"document/{doc_id}/counter/{counter_name}", "watch-document-counter", "unwatch-document-counter", - value=None, - values=[doc_id, counter_name], + resources_names=[doc_id, counter_name], )(lambda x: x.document_id.casefold() == doc_id.casefold() and x.name.casefold()) return observable - def get_or_add_observable(self, group, name, watch_command, unwatch_command, value, values=None): - if group not in self._observables: - self._observables[group] = {} + def get_or_add_observable( + self, + group: str, + name: str, + watch_command: str, + unwatch_command: str, + resource_name: Optional[str] = None, + resources_names: Optional[List[str]] = None, + ): + if group not in self._observables_by_group: + self._observables_by_group[group] = {} - if name not in self._observables[group]: + if name not in self._observables_by_group[group]: def on_disconnect(): try: if self.client_websocket.connected: - self.send(unwatch_command, value, values) + self.send(unwatch_command, resource_name, resources_names) except websocket.WebSocketException: pass def on_connect(): - self.send(watch_command, value, values) + self.send(watch_command, resource_name, resources_names) observable = Observable( on_connect=on_connect, on_disconnect=on_disconnect, executor=self._executor, ) - self._observables[group][name] = observable + self._observables_by_group[group][name] = observable if self._immediate_connection != 0: observable.set(self._executor.submit(observable.on_connect)) - return self._observables[group][name] + return self._observables_by_group[group][name] - def send(self, command, value, values=None): + def send(self, command: str, value: Optional[str], values: Optional[List[str]] = None): current_command_id = 0 future = Future() try: diff --git a/ravendb/changes/observers.py b/ravendb/changes/observers.py index ec89bcec..0baaf16f 100644 --- a/ravendb/changes/observers.py +++ b/ravendb/changes/observers.py @@ -1,7 +1,7 @@ from __future__ import annotations -from concurrent.futures import Future +from concurrent.futures import Future, ThreadPoolExecutor from threading import Lock -from typing import Callable, Generic, TypeVar +from typing import Callable, Generic, TypeVar, Optional from ravendb.tools.concurrentset import ConcurrentSet @@ -9,7 +9,12 @@ class Observable(Generic[_T_Change]): - def __init__(self, on_connect=None, on_disconnect=None, executor=None): + def __init__( + self, + on_connect: Optional[Callable[[], None]] = None, + on_disconnect: Optional[Callable[[], None]] = None, + executor: Optional[ThreadPoolExecutor] = None, + ): self.on_connect = on_connect self._on_disconnect = on_disconnect self.last_exception = None @@ -75,7 +80,7 @@ def dec(self): if self._value == 0: self.set(self._executor.submit(self._on_disconnect)) - def set(self, future): + def set(self, future: Future) -> None: if not self._future_set.done(): def done_callback(f): @@ -92,7 +97,7 @@ def done_callback(f): future.add_done_callback(done_callback) self._future = future - def error(self, exception): + def error(self, exception: Exception): future = Future() self.set(future) future.set_exception(exception) diff --git a/ravendb/documents/store/definition.py b/ravendb/documents/store/definition.py index 8bb5e8d5..7e836c95 100644 --- a/ravendb/documents/store/definition.py +++ b/ravendb/documents/store/definition.py @@ -320,7 +320,7 @@ def __init__(self, urls: Union[str, List[str]] = None, database: Optional[str] = self.__multi_db_hilo: Optional[MultiDatabaseHiLoGenerator] = None self.__identifier: Optional[str] = None self.__add_change_lock = threading.Lock() - self.__database_changes = {} + self.__database_changes: Dict[str, DatabaseChanges] = {} self.__after_close: List[Callable[[], None]] = [] self.__before_close: List[Callable[[], None]] = [] self.__time_series_operation: Optional[TimeSeriesOperations] = None @@ -493,8 +493,8 @@ def changes(self, database=None, on_error=None, executor=None) -> DatabaseChange ) return self.__database_changes[database] - def __on_close_change(self, database): - self.__database_changes.pop(database, None) + def __on_close_change(self, database_name: str): + self.__database_changes.pop(database_name, None) def set_request_timeout(self, timeout: datetime.timedelta, database: Optional[str] = None) -> Callable[[], None]: self.assert_initialized() diff --git a/ravendb/serverwide/operations/common.py b/ravendb/serverwide/operations/common.py index 6140ca15..e2aa1a0f 100644 --- a/ravendb/serverwide/operations/common.py +++ b/ravendb/serverwide/operations/common.py @@ -143,7 +143,8 @@ def __init__( def create_request(self, node: ServerNode) -> requests.Request: url = ( - f"{node.url}/admin/databases?name={self.__database_name}&replicationFactor={self.__replication_factor}" + f"{node.url}/admin/databases?name={self.__database_name}" + f"&replicationFactor={self.__replication_factor}&?raft-request-id={self.get_raft_unique_request_id}" ) request = requests.Request("PUT") diff --git a/ravendb/tests/jvm_migrated_tests/client_tests/indexing_tests/test_indexes_from_client.py b/ravendb/tests/jvm_migrated_tests/client_tests/indexing_tests/test_indexes_from_client.py index 20481e61..39dce71f 100644 --- a/ravendb/tests/jvm_migrated_tests/client_tests/indexing_tests/test_indexes_from_client.py +++ b/ravendb/tests/jvm_migrated_tests/client_tests/indexing_tests/test_indexes_from_client.py @@ -173,7 +173,7 @@ def test_can_stop_and_start(self): self.assertEqual(IndexRunningStatus.RUNNING, status.status) - self.assertEquals(1, len(status.indexes)) + self.assertEqual(1, len(status.indexes)) self.assertEqual(IndexRunningStatus.RUNNING, status.indexes[0].status) diff --git a/ravendb/tests/jvm_migrated_tests/issues_tests/test_ravenDB_6967.py b/ravendb/tests/jvm_migrated_tests/issues_tests/test_ravenDB_6967.py index e4585466..559214c1 100644 --- a/ravendb/tests/jvm_migrated_tests/issues_tests/test_ravenDB_6967.py +++ b/ravendb/tests/jvm_migrated_tests/issues_tests/test_ravenDB_6967.py @@ -78,7 +78,7 @@ def test_can_delete_index_errors(self): index_errors3 = self.store.maintenance.send(GetIndexErrorsOperation("Index3")) self.assertGreater(sum([len(x.errors) for x in index_errors1]), 0) - self.assertEquals(sum([len(x.errors) for x in index_errors2]), 0) + self.assertEqual(sum([len(x.errors) for x in index_errors2]), 0) self.assertGreater(sum([len(x.errors) for x in index_errors3]), 0) self.store.maintenance.send(DeleteIndexErrorsOperation()) @@ -87,8 +87,8 @@ def test_can_delete_index_errors(self): index_errors2 = self.store.maintenance.send(GetIndexErrorsOperation("Index2")) index_errors3 = self.store.maintenance.send(GetIndexErrorsOperation("Index3")) - self.assertEquals(sum([len(x.errors) for x in index_errors1]), 0) - self.assertEquals(sum([len(x.errors) for x in index_errors2]), 0) - self.assertEquals(sum([len(x.errors) for x in index_errors3]), 0) + self.assertEqual(sum([len(x.errors) for x in index_errors1]), 0) + self.assertEqual(sum([len(x.errors) for x in index_errors2]), 0) + self.assertEqual(sum([len(x.errors) for x in index_errors3]), 0) RavenTestHelper.assert_no_index_errors(self.store) diff --git a/ravendb/tests/jvm_migrated_tests/server_tests/documents/notifications/test_changes.py b/ravendb/tests/jvm_migrated_tests/server_tests/documents/notifications/test_changes.py index 7da6414a..4f5e1164 100644 --- a/ravendb/tests/jvm_migrated_tests/server_tests/documents/notifications/test_changes.py +++ b/ravendb/tests/jvm_migrated_tests/server_tests/documents/notifications/test_changes.py @@ -3,7 +3,7 @@ from ravendb import AbstractIndexCreationTask, SetIndexesPriorityOperation from ravendb.changes.observers import ActionObserver -from ravendb.changes.types import DocumentChange, IndexChange +from ravendb.changes.types import DocumentChange, IndexChange, DocumentChangeType from ravendb.documents.indexes.definitions import IndexPriority from ravendb.infrastructure.entities import User from ravendb.infrastructure.orders import Order @@ -235,3 +235,62 @@ def __ev(value: DocumentChange): self.assertEqual("users/2", document_changes[1].key) close_action() + + def test_changes_with_https(self): + event = Event() + changes_list = [] + exception = None + + def _on_error(e): + nonlocal exception + exception = e + + changes = self.secured_document_store.changes(on_error=_on_error) + observable = changes.for_document("users/1") + + def __ev(value: DocumentChange): + changes_list.append(value) + event.set() + + observer = ActionObserver(__ev) + close_action = observable.subscribe_with_observer(observer) + try: + observable.ensure_subscribe_now() + except Exception: + raise exception + + with self.secured_document_store.open_session() as session: + user = User() + session.store(user, "users/1") + session.save_changes() + + event.wait(2) + document_change = changes_list[0] + self.assertIsNotNone(document_change) + self.assertEqual("users/1", document_change.key) + self.assertEqual(DocumentChangeType.PUT, document_change.type_of_change) + + changes_list.clear() + + try: + event.wait(1) + except Exception: + pass + + self.assertEqual(0, len(changes_list)) + close_action() + # at this point we should be unsubscribed from changes on 'users/1' + + with self.secured_document_store.open_session() as session: + user = User() + user.name = "another name" + session.store(user, "users/1") + session.save_changes() + + # it should be empty + try: + event.wait(1) + except Exception: + pass + + self.assertEqual(0, len(changes_list)) diff --git a/ravendb/tools/parsers.py b/ravendb/tools/parsers.py index 7bfd4453..8e1f8ee3 100644 --- a/ravendb/tools/parsers.py +++ b/ravendb/tools/parsers.py @@ -1,4 +1,6 @@ from decimal import InvalidOperation +from typing import Any, Dict, Optional + from ijson.common import integer_or_decimal, IncompleteJSONError from ijson.backends.python import UnexpectedSymbol from _elementtree import ParseError @@ -112,7 +114,7 @@ def create_object(self, gen): raise ParseError("End object expected, but the generator ended before we got it") - def next_object(self): + def next_object(self) -> Optional[Dict[str, Any]]: try: (_, text) = next(self.lexer) if IS_WEBSOCKET and text == ",": diff --git a/ravendb/util/tcp_utils.py b/ravendb/util/tcp_utils.py index 5f7dde77..1e624bf4 100644 --- a/ravendb/util/tcp_utils.py +++ b/ravendb/util/tcp_utils.py @@ -17,11 +17,13 @@ def connect( ) -> socket.socket: hostname, port = url_string.replace("tcp://", "").split(":") s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + is_ssl_socket = server_certificate_base64 and client_certificate_pem_path - if server_certificate_base64 and client_certificate_pem_path: + if is_ssl_socket: context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) context.load_cert_chain(client_certificate_pem_path, password=certificate_private_key_password) s = context.wrap_socket(s) + s.connect((hostname, int(port))) if is_ssl_socket and base64.b64decode(server_certificate_base64) != s.getpeercert(True): raise ConnectionError("Failed to validate public server certificate.")