From 5364a84576afcd1a835d02d8c3f21dc82981cedb Mon Sep 17 00:00:00 2001 From: Wei Ouyang Date: Sat, 9 Oct 2021 21:05:51 +0200 Subject: [PATCH] refactor service config and support require context (#227) --- imjoy/core/__init__.py | 28 ++++++++++++++++++--- imjoy/core/interface.py | 56 +++++++++++++++++++++-------------------- imjoy/server.py | 2 +- requirements.txt | 2 +- setup.py | 4 +-- tests/test_server.py | 18 ++++++++++++- 6 files changed, 75 insertions(+), 35 deletions(-) diff --git a/imjoy/core/__init__.py b/imjoy/core/__init__.py index 8d6ceb59..d3e644d3 100644 --- a/imjoy/core/__init__.py +++ b/imjoy/core/__init__.py @@ -63,13 +63,22 @@ class StatusEnum(str, Enum): not_initialized = "not_initialized" +class ServiceConfig(BaseModel): + """Represent service config.""" + + visibility: VisibilityEnum = VisibilityEnum.protected + require_context: bool = False + workspace: str + id: str + + class ServiceInfo(BaseModel): """Represent service.""" - config: Dict[str, Any] + config: ServiceConfig name: str type: str - visibility: VisibilityEnum = VisibilityEnum.protected + _provider: DynamicPlugin = PrivateAttr(default_factory=lambda: None) class Config: @@ -78,9 +87,22 @@ class Config: extra = Extra.allow def set_provider(self, provider: DynamicPlugin) -> None: - """Return the plugins.""" + """Set the provider plugin.""" self._provider = provider + def get_provider(self) -> DynamicPlugin: + """Get the provider plugin.""" + return self._provider + + def get_summary(self) -> dict: + """Get a summary about the service.""" + return { + "name": self.name, + "type": self.type, + "provider": self._provider.name, + "provider_id": self._provider.id, + }.update(self.config.dict()) + class UserInfo(BaseModel): """Represent user info.""" diff --git a/imjoy/core/interface.py b/imjoy/core/interface.py index 8bfea384..04f7e26a 100644 --- a/imjoy/core/interface.py +++ b/imjoy/core/interface.py @@ -235,29 +235,31 @@ def register_service(self, service: dict): raise Exception("Service should at least contain `name` and `type`") # TODO: check if it's already exists - config = service.get("config", {}) - assert isinstance(config, dict), "service.config must be a dictionary" - if config.get("name") and service["name"] != config.get("name"): - raise Exception("Service name should match the one in the service.config.") - if config.get("type") and service["type"] != config.get("type"): - raise Exception("Service type should match the one in the service.config.") - if config.get("visibility") and service["visibility"] != config.get( - "visibility" - ): - raise Exception( - "Service visibility should match the one in the service.config." - ) - service["visibility"] = service.get("visibility", "protected") - config["name"] = service["name"] - config["type"] = service["type"] - config["visibility"] = service["visibility"] - config["workspace"] = workspace.name - config["id"] = service_id - config["provider"] = plugin.name - config["provider_id"] = plugin.id - service["config"] = config + service.config = service.get("config", {}) + assert isinstance(service.config, dict), "service.config must be a dictionary" + service.config["id"] = service_id + service.config["workspace"] = workspace.name formated_service = ServiceInfo.parse_obj(service) formated_service.set_provider(plugin) + service_dict = formated_service.dict() + if formated_service.config.require_context: + for key in service_dict: + if callable(service_dict[key]): + + def wrap_func(func, *args, **kwargs): + user_info = self.current_user.get() + workspace = self.current_workspace.get() + kwargs["context"] = { + "user_id": user_info.id, + "email": user_info.email, + "is_anonymous": user_info.email, + "workspace": workspace.name, + } + return func(*args, **kwargs) + + setattr( + formated_service, key, partial(wrap_func, service_dict[key]) + ) # service["_rintf"] = True # Note: service can set its `visibility` to `public` or `protected` workspace.set_service(formated_service.name, formated_service) @@ -295,7 +297,7 @@ async def get_service(self, service_id): user_info = self.current_user.get() if ( not self.check_permission(workspace, user_info) - and service.config.get("visibility", "protected") != "public" + and service.config.visibility != VisibilityEnum.public ): raise Exception(f"Permission denied: {service_id}") @@ -323,15 +325,15 @@ def list_services(self, query: Optional[dict] = None): # To access the service, it should be public or owned by the user if ( not can_access_ws - and service.config.get("visibility", "protected") != "public" + and service.config.visibility != VisibilityEnum.public ): continue match = True for key in query: - if service.config[key] != query[key]: + if getattr(service, key) != query[key]: match = False if match: - ret.append(service.config) + ret.append(service.get_summary()) return ret if ws is not None: workspace = self.get_workspace(ws) @@ -342,10 +344,10 @@ def list_services(self, query: Optional[dict] = None): for service in workspace_services.values(): match = True for key in query: - if service.config[key] != query[key]: + if getattr(service, key) != query[key]: match = False if match: - ret.append(service.config) + ret.append(service.get_summary()) if workspace is None: raise Exception("Workspace not found: {ws}") diff --git a/imjoy/server.py b/imjoy/server.py index 97ceccf0..fe428a66 100644 --- a/imjoy/server.py +++ b/imjoy/server.py @@ -207,7 +207,7 @@ async def disconnect(sid): # with the plugins of the previous owners plugin_services = plugin.workspace.get_services() for service in list(plugin_services.values()): - if service.config.get("provider_id") == plugin.id: + if service.get_provider() == plugin: plugin.workspace.remove_service(service.name) del core_interface.all_sessions[sid] bus.emit("plugin_disconnected", {"sid": sid}) diff --git a/requirements.txt b/requirements.txt index adbda4f8..490ad0e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ fastapi==0.68.1 imjoy-jupyter-extension==0.2.17 -imjoy-rpc==0.3.12 +imjoy-rpc==0.3.30 ipykernel==5.5.5 jupyter==1.0.0 numpy==1.19.5 # needs to stay compatible with latest tensorflow diff --git a/setup.py b/setup.py index 7a77e4a4..13bdbdd6 100644 --- a/setup.py +++ b/setup.py @@ -14,11 +14,11 @@ # pylint: disable=unused-import import google.colab.output # noqa: F401 - REQUIREMENTS = ["numpy", "imjoy-rpc>=0.3.12", "imjoy-elfinder"] + REQUIREMENTS = ["numpy", "imjoy-rpc>=0.3.30", "imjoy-elfinder"] except ImportError: REQUIREMENTS = [ "numpy", - "imjoy-rpc>=0.3.12", + "imjoy-rpc>=0.3.30", "pydantic[email]>=1.8.2", "typing-extensions>=3.7.4.3", # required by pydantic "python-dotenv>=0.17.0", diff --git a/tests/test_server.py b/tests/test_server.py index 745e826a..6e93504e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -166,7 +166,23 @@ async def test_workspace(socketio_server): } ) service = await ws.get_service(service_id) - assert service.config["name"] == "test_service" + assert service["name"] == "test_service" + + def test(context=None): + return context + + service_id = await ws.register_service( + { + "name": "test_service", + "type": "#test", + "config": {"require_context": True}, + "test": test, + } + ) + service = await ws.get_service(service_id) + context = await service.test() + assert "user_id" in context and "email" in context + assert service["name"] == "test_service" # we should not get it because api is in another workspace ss2 = await api.list_services({"type": "#test"})