diff --git a/jupyter_collaboration/handlers.py b/jupyter_collaboration/handlers.py index 25d338c1..01c87df2 100644 --- a/jupyter_collaboration/handlers.py +++ b/jupyter_collaboration/handlers.py @@ -209,7 +209,7 @@ async def recv(self): message = await self._message_queue.get() return message - def on_message(self, message): + async def on_message(self, message): """ On message receive. """ @@ -240,6 +240,9 @@ def on_message(self, message): ) return skip + if message_type == MessageType.ROOM: + await self.room.handle_msg(message[1:]) + if message_type == MessageType.CHAT: msg = message[2:].decode("utf-8") @@ -316,7 +319,7 @@ async def _clean_room(self) -> None: file = self._file_loaders[file_id] if file.number_of_subscriptions == 0: self.log.info("Deleting file %s", file.path) - del self._file_loaders[file_id] + await self._file_loaders.remove(file_id) self._emit(LogLevel.INFO, "clean", "Loader deleted.") def check_origin(self, origin): diff --git a/jupyter_collaboration/loaders.py b/jupyter_collaboration/loaders.py index eba3fc8a..02ae2b84 100644 --- a/jupyter_collaboration/loaders.py +++ b/jupyter_collaboration/loaders.py @@ -118,6 +118,26 @@ async def load_content(self, format: str, file_type: str, content: bool) -> dict ) async def save_content(self, model: dict[str, Any]) -> dict[str, Any]: + """ + Save the content of the file. + + Parameters: + model (dict): A dictionary with format, type, last_modified, and content of the file. + + Returns: + model (dict): A dictionary with the metadata and content of the file. + """ + async with self._lock: + path = self.path + if model["type"] not in {"directory", "file", "notebook"}: + # fall back to file if unknown type, the content manager only knows + # how to handle these types + model["type"] = "file" + + self._log.info("Saving file: %s", path) + return await ensure_async(self._contents_manager.save(model, path)) + + async def maybe_save_content(self, model: dict[str, Any]) -> dict[str, Any]: """ Save the content of the file. diff --git a/jupyter_collaboration/rooms.py b/jupyter_collaboration/rooms.py index 38823769..a34ffabf 100644 --- a/jupyter_collaboration/rooms.py +++ b/jupyter_collaboration/rooms.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio +import uuid from logging import Logger from typing import Any @@ -11,9 +12,16 @@ from jupyter_ydoc import ydocs as YDOCS from ypy_websocket.websocket_server import YRoom from ypy_websocket.ystore import BaseYStore, YDocNotFound +from ypy_websocket.yutils import write_var_uint from .loaders import FileLoader -from .utils import JUPYTER_COLLABORATION_EVENTS_URI, LogLevel, OutOfBandChanges +from .utils import ( + JUPYTER_COLLABORATION_EVENTS_URI, + LogLevel, + MessageType, + OutOfBandChanges, + RoomMessages, +) YFILE = YDOCS["file"] @@ -45,9 +53,11 @@ def __init__( self._save_delay = save_delay self._update_lock = asyncio.Lock() + self._outofband_lock = asyncio.Lock() self._initialization_lock = asyncio.Lock() self._cleaner: asyncio.Task | None = None self._saving_document: asyncio.Task | None = None + self._messages: dict[str, asyncio.Lock] = {} # Listen for document changes self._document.observe(self._on_document_change) @@ -149,6 +159,41 @@ async def initialize(self) -> None: self.ready = True self._emit(LogLevel.INFO, "initialize", "Room initialized") + async def handle_msg(self, data: bytes) -> None: + msg_type = data[0] + msg_id = data[2:].decode() + + # Use a lock to prevent handling responses from multiple clients + # at the same time + async with self._messages[msg_id]: + # Check whether the previous client resolved the conflict + if msg_id not in self._messages: + return + + try: + ans = None + if msg_type == RoomMessages.RELOAD: + # Restore the room with the content from disk + await self._load_document() + ans = RoomMessages.DOC_OVERWRITTEN + + elif msg_type == RoomMessages.OVERWRITE: + # Overwrite the file with content from the room + await self._save_document() + ans = RoomMessages.FILE_OVERWRITTEN + + if ans is not None: + # Remove the lock and broadcast the resolution + self._messages.pop(msg_id) + data = msg_id.encode() + self._outofband_lock.release() + await self._broadcast_msg( + bytes([MessageType.ROOM, ans]) + write_var_uint(len(data)) + data + ) + + except Exception: + return + def _emit(self, level: LogLevel, action: str | None = None, msg: str | None = None) -> None: data = {"level": level.value, "room": self._room_id, "path": self._file.path} if action: @@ -187,24 +232,24 @@ async def _on_content_change(self, event: str, args: dict[str, Any]) -> None: event (str): Type of change. args (dict): A dictionary with format, type, last_modified. """ + if self._outofband_lock.locked(): + return + if event == "metadata" and ( self._last_modified is None or self._last_modified < args["last_modified"] ): self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id) self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes. Overwriting the room.") - try: - model = await self._file.load_content(self._file_format, self._file_type, True) - except Exception as e: - msg = f"Error loading content from file: {self._file.path}\n{e!r}" - self.log.error(msg, exc_info=e) - self._emit(LogLevel.ERROR, None, msg) - return None - - async with self._update_lock: - self._document.source = model["content"] - self._last_modified = model["last_modified"] - self._document.dirty = False + msg_id = str(uuid.uuid4()) + self._messages[msg_id] = asyncio.Lock() + await self._outofband_lock.acquire() + data = msg_id.encode() + await self._broadcast_msg( + bytes([MessageType.ROOM, RoomMessages.FILE_CHANGED]) + + write_var_uint(len(data)) + + data + ) def _on_document_change(self, target: str, event: Any) -> None: """ @@ -231,6 +276,45 @@ def _on_document_change(self, target: str, event: Any) -> None: self._saving_document = asyncio.create_task(self._maybe_save_document()) + async def _load_document(self) -> None: + try: + model = await self._file.load_content(self._file_format, self._file_type, True) + except Exception as e: + msg = f"Error loading content from file: {self._file.path}\n{e!r}" + self.log.error(msg, exc_info=e) + self._emit(LogLevel.ERROR, None, msg) + return None + + async with self._update_lock: + self._document.source = model["content"] + self._last_modified = model["last_modified"] + self._document.dirty = False + + async def _save_document(self) -> None: + """ + Saves the content of the document to disk. + """ + try: + self.log.info("Saving the content from room %s", self._room_id) + model = await self._file.save_content( + { + "format": self._file_format, + "type": self._file_type, + "last_modified": self._last_modified, + "content": self._document.source, + } + ) + self._last_modified = model["last_modified"] + async with self._update_lock: + self._document.dirty = False + + self._emit(LogLevel.INFO, "save", "Content saved.") + + except Exception as e: + msg = f"Error saving file: {self._file.path}\n{e!r}" + self.log.error(msg, exc_info=e) + self._emit(LogLevel.ERROR, None, msg) + async def _maybe_save_document(self) -> None: """ Saves the content of the document to disk. @@ -248,7 +332,7 @@ async def _maybe_save_document(self) -> None: try: self.log.info("Saving the content from room %s", self._room_id) - model = await self._file.save_content( + model = await self._file.maybe_save_content( { "format": self._file_format, "type": self._file_type, @@ -284,6 +368,10 @@ async def _maybe_save_document(self) -> None: self.log.error(msg, exc_info=e) self._emit(LogLevel.ERROR, None, msg) + async def _broadcast_msg(self, msg: bytes) -> None: + for client in self.clients: + await client.send(msg) + class TransientRoom(YRoom): """A Y room for sharing state (e.g. awareness).""" diff --git a/jupyter_collaboration/utils.py b/jupyter_collaboration/utils.py index e6c974cf..cf1c7bc8 100644 --- a/jupyter_collaboration/utils.py +++ b/jupyter_collaboration/utils.py @@ -12,9 +12,18 @@ class MessageType(IntEnum): SYNC = 0 AWARENESS = 1 + ROOM = 124 CHAT = 125 +class RoomMessages(IntEnum): + RELOAD = 0 + OVERWRITE = 1 + FILE_CHANGED = 2 + FILE_OVERWRITTEN = 3 + DOC_OVERWRITTEN = 4 + + class LogLevel(Enum): INFO = "INFO" DEBUG = "DEBUG" diff --git a/packages/docprovider/src/awareness.ts b/packages/docprovider/src/awareness.ts index 48d0afa1..762a9f60 100644 --- a/packages/docprovider/src/awareness.ts +++ b/packages/docprovider/src/awareness.ts @@ -14,12 +14,9 @@ import * as decoding from 'lib0/decoding'; import * as encoding from 'lib0/encoding'; import { WebsocketProvider } from 'y-websocket'; +import { MessageType } from './utils'; import { IAwarenessProvider } from './tokens'; -export enum MessageType { - CHAT = 125 -} - export interface IContent { type: string; body: string; diff --git a/packages/docprovider/src/utils.ts b/packages/docprovider/src/utils.ts new file mode 100644 index 00000000..47815952 --- /dev/null +++ b/packages/docprovider/src/utils.ts @@ -0,0 +1,17 @@ +/* ----------------------------------------------------------------------------- +| Copyright (c) Jupyter Development Team. +| Distributed under the terms of the Modified BSD License. +|----------------------------------------------------------------------------*/ + +export enum MessageType { + ROOM = 124, + CHAT = 125 +} + +export enum RoomMessage { + RELOAD = 0, + OVERWRITE = 1, + FILE_CHANGED = 2, + FILE_OVERWRITTEN = 3, + DOC_OVERWRITTEN = 4 +} diff --git a/packages/docprovider/src/yprovider.ts b/packages/docprovider/src/yprovider.ts index 4061ebeb..566b3f0f 100644 --- a/packages/docprovider/src/yprovider.ts +++ b/packages/docprovider/src/yprovider.ts @@ -13,10 +13,13 @@ import { Signal } from '@lumino/signaling'; import { DocumentChange, YDocument } from '@jupyter/ydoc'; +import * as decoding from 'lib0/decoding'; +import * as encoding from 'lib0/encoding'; import { Awareness } from 'y-protocols/awareness'; import { WebsocketProvider as YWebsocketProvider } from 'y-websocket'; import { requestDocSession } from './requests'; +import { MessageType, RoomMessage } from './utils'; /** * An interface for a document provider. @@ -111,6 +114,18 @@ export class WebSocketProvider implements IDocumentProvider { this._yWebsocketProvider.on('sync', this._onSync); this._yWebsocketProvider.on('connection-close', this._onConnectionClosed); + + this._yWebsocketProvider.messageHandlers[MessageType.ROOM] = ( + encoder, + decoder, + provider, + emitSynced, + messageType + ) => { + const msgType = decoding.readVarUint(decoder); + const data = decoding.readVarString(decoder); + this._handleRoomMessage(msgType, data); + }; } private _onUserChanged(user: User.IManager): void { @@ -138,6 +153,59 @@ export class WebSocketProvider implements IDocumentProvider { } }; + private _handleRoomMessage(type: number, data: string): void { + switch (type) { + case RoomMessage.FILE_CHANGED: + this._handleFileChanged(data); + break; + + case RoomMessage.DOC_OVERWRITTEN: + case RoomMessage.FILE_OVERWRITTEN: + if (this._dialog) { + this._dialog.close(); + this._dialog = null; + } + break; + } + } + + private _handleFileChanged(data: string): void { + this._dialog = new Dialog({ + title: this._trans.__('File changed'), + body: this._trans.__('Do you want to overwrite the file or reload it?'), + buttons: [ + Dialog.okButton({ label: 'Reload' }), + Dialog.warnButton({ label: 'Overwrite' }) + ], + hasClose: false + }); + + this._dialog.launch().then(resp => { + if (resp.button.label === 'Reload') { + this._sendReloadMsg(data); + } else if (resp.button.label === 'Overwrite') { + this._sendOverwriteMsg(data); + } + }); + } + + private _sendReloadMsg(data: string): void { + const encoder = encoding.createEncoder(); + encoding.writeVarUint(encoder, MessageType.ROOM); + encoding.writeVarUint(encoder, RoomMessage.RELOAD); + encoding.writeVarString(encoder, data); + this._yWebsocketProvider?.ws!.send(encoding.toUint8Array(encoder)); + } + + private _sendOverwriteMsg(data: string): void { + const encoder = encoding.createEncoder(); + encoding.writeVarUint(encoder, MessageType.ROOM); + encoding.writeVarUint(encoder, RoomMessage.OVERWRITE); + encoding.writeVarString(encoder, data); + this._yWebsocketProvider?.ws!.send(encoding.toUint8Array(encoder)); + } + + private _dialog: Dialog | null = null; private _awareness: Awareness; private _contentType: string; private _format: string;