Skip to content
This repository has been archived by the owner on Aug 16, 2022. It is now read-only.

Commit

Permalink
refactor service config and support require context (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway authored Oct 9, 2021
1 parent 3ee4d39 commit 5364a84
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 35 deletions.
28 changes: 25 additions & 3 deletions imjoy/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,22 @@ class StatusEnum(str, Enum):
not_initialized = "not_initialized"


class ServiceConfig(BaseModel):
"""Represent service config."""

visibility: VisibilityEnum = VisibilityEnum.protected
require_context: bool = False
workspace: str
id: str


class ServiceInfo(BaseModel):
"""Represent service."""

config: Dict[str, Any]
config: ServiceConfig
name: str
type: str
visibility: VisibilityEnum = VisibilityEnum.protected

_provider: DynamicPlugin = PrivateAttr(default_factory=lambda: None)

class Config:
Expand All @@ -78,9 +87,22 @@ class Config:
extra = Extra.allow

def set_provider(self, provider: DynamicPlugin) -> None:
"""Return the plugins."""
"""Set the provider plugin."""
self._provider = provider

def get_provider(self) -> DynamicPlugin:
"""Get the provider plugin."""
return self._provider

def get_summary(self) -> dict:
"""Get a summary about the service."""
return {
"name": self.name,
"type": self.type,
"provider": self._provider.name,
"provider_id": self._provider.id,
}.update(self.config.dict())


class UserInfo(BaseModel):
"""Represent user info."""
Expand Down
56 changes: 29 additions & 27 deletions imjoy/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,29 +235,31 @@ def register_service(self, service: dict):
raise Exception("Service should at least contain `name` and `type`")

# TODO: check if it's already exists
config = service.get("config", {})
assert isinstance(config, dict), "service.config must be a dictionary"
if config.get("name") and service["name"] != config.get("name"):
raise Exception("Service name should match the one in the service.config.")
if config.get("type") and service["type"] != config.get("type"):
raise Exception("Service type should match the one in the service.config.")
if config.get("visibility") and service["visibility"] != config.get(
"visibility"
):
raise Exception(
"Service visibility should match the one in the service.config."
)
service["visibility"] = service.get("visibility", "protected")
config["name"] = service["name"]
config["type"] = service["type"]
config["visibility"] = service["visibility"]
config["workspace"] = workspace.name
config["id"] = service_id
config["provider"] = plugin.name
config["provider_id"] = plugin.id
service["config"] = config
service.config = service.get("config", {})
assert isinstance(service.config, dict), "service.config must be a dictionary"
service.config["id"] = service_id
service.config["workspace"] = workspace.name
formated_service = ServiceInfo.parse_obj(service)
formated_service.set_provider(plugin)
service_dict = formated_service.dict()
if formated_service.config.require_context:
for key in service_dict:
if callable(service_dict[key]):

def wrap_func(func, *args, **kwargs):
user_info = self.current_user.get()
workspace = self.current_workspace.get()
kwargs["context"] = {
"user_id": user_info.id,
"email": user_info.email,
"is_anonymous": user_info.email,
"workspace": workspace.name,
}
return func(*args, **kwargs)

setattr(
formated_service, key, partial(wrap_func, service_dict[key])
)
# service["_rintf"] = True
# Note: service can set its `visibility` to `public` or `protected`
workspace.set_service(formated_service.name, formated_service)
Expand Down Expand Up @@ -295,7 +297,7 @@ async def get_service(self, service_id):
user_info = self.current_user.get()
if (
not self.check_permission(workspace, user_info)
and service.config.get("visibility", "protected") != "public"
and service.config.visibility != VisibilityEnum.public
):
raise Exception(f"Permission denied: {service_id}")

Expand Down Expand Up @@ -323,15 +325,15 @@ def list_services(self, query: Optional[dict] = None):
# To access the service, it should be public or owned by the user
if (
not can_access_ws
and service.config.get("visibility", "protected") != "public"
and service.config.visibility != VisibilityEnum.public
):
continue
match = True
for key in query:
if service.config[key] != query[key]:
if getattr(service, key) != query[key]:
match = False
if match:
ret.append(service.config)
ret.append(service.get_summary())
return ret
if ws is not None:
workspace = self.get_workspace(ws)
Expand All @@ -342,10 +344,10 @@ def list_services(self, query: Optional[dict] = None):
for service in workspace_services.values():
match = True
for key in query:
if service.config[key] != query[key]:
if getattr(service, key) != query[key]:
match = False
if match:
ret.append(service.config)
ret.append(service.get_summary())

if workspace is None:
raise Exception("Workspace not found: {ws}")
Expand Down
2 changes: 1 addition & 1 deletion imjoy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def disconnect(sid):
# with the plugins of the previous owners
plugin_services = plugin.workspace.get_services()
for service in list(plugin_services.values()):
if service.config.get("provider_id") == plugin.id:
if service.get_provider() == plugin:
plugin.workspace.remove_service(service.name)
del core_interface.all_sessions[sid]
bus.emit("plugin_disconnected", {"sid": sid})
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
fastapi==0.68.1
imjoy-jupyter-extension==0.2.17
imjoy-rpc==0.3.12
imjoy-rpc==0.3.30
ipykernel==5.5.5
jupyter==1.0.0
numpy==1.19.5 # needs to stay compatible with latest tensorflow
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# pylint: disable=unused-import
import google.colab.output # noqa: F401

REQUIREMENTS = ["numpy", "imjoy-rpc>=0.3.12", "imjoy-elfinder"]
REQUIREMENTS = ["numpy", "imjoy-rpc>=0.3.30", "imjoy-elfinder"]
except ImportError:
REQUIREMENTS = [
"numpy",
"imjoy-rpc>=0.3.12",
"imjoy-rpc>=0.3.30",
"pydantic[email]>=1.8.2",
"typing-extensions>=3.7.4.3", # required by pydantic
"python-dotenv>=0.17.0",
Expand Down
18 changes: 17 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,23 @@ async def test_workspace(socketio_server):
}
)
service = await ws.get_service(service_id)
assert service.config["name"] == "test_service"
assert service["name"] == "test_service"

def test(context=None):
return context

service_id = await ws.register_service(
{
"name": "test_service",
"type": "#test",
"config": {"require_context": True},
"test": test,
}
)
service = await ws.get_service(service_id)
context = await service.test()
assert "user_id" in context and "email" in context
assert service["name"] == "test_service"

# we should not get it because api is in another workspace
ss2 = await api.list_services({"type": "#test"})
Expand Down

0 comments on commit 5364a84

Please sign in to comment.