Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

filter allowed_types #71

Merged
merged 1 commit into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions hivemind_core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
def cast_to_client_obj():
valid_kwargs: Iterable[str] = ("client_id", "api_key", "name",
"description", "is_admin", "last_seen",
"blacklist", "crypto_key", "password")
"blacklist", "allowed_types", "crypto_key",
"password")

def _handler(func):

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self,
is_admin: bool = False,
last_seen: float = -1,
blacklist: Optional[Dict[str, List[str]]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None):

Expand All @@ -62,6 +64,9 @@ def __init__(self,
"skills": [],
"intents": []
}
self.allowed_types = allowed_types or ["recognizer_loop:utterance"]
if "recognizer_loop:utterance" not in self.allowed_types:
self.allowed_types.append("recognizer_loop:utterance")

def __getitem__(self, item: str) -> Any:
return self.__dict__.get(item)
Expand Down Expand Up @@ -179,6 +184,7 @@ def add_client(self,
key: str = "",
admin: bool = False,
blacklist: Optional[Dict[str, Any]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None) -> Client:

Expand All @@ -191,6 +197,8 @@ def add_client(self,
user["name"] = name
if blacklist:
user["blacklist"] = blacklist
if allowed_types:
user["allowed_types"] = allowed_types
if admin is not None:
user["is_admin"] = admin
if crypto_key:
Expand All @@ -202,7 +210,8 @@ def add_client(self,
user = Client(api_key=key, name=name,
blacklist=blacklist, crypto_key=crypto_key,
client_id=self.total_clients() + 1,
is_admin=admin, password=password)
is_admin=admin, password=password,
allowed_types=allowed_types)
self.add_item(user)
return user

Expand Down
20 changes: 17 additions & 3 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class HiveMindClientConnection:
socket: Optional[WebSocketHandler] = None
crypto_key: Optional[str] = None
blacklist: List[str] = field(default_factory=list) # list of ovos message_type to never be sent to this client
allowed_types: List[str] = field(default_factory=list) # list of ovos message_type to allow to be sent from this client

@property
def peer(self) -> str:
Expand All @@ -62,6 +63,16 @@ def peer(self) -> str:
return f"{self.name}:{self.ip}::{self.sess.session_id}"

def send(self, message: HiveMessage):
# TODO some cleaning around HiveMessage
if isinstance(message.payload, dict):
_msg_type = message.payload.get("type")
else:
_msg_type = message.payload.msg_type

if _msg_type in self.blacklist:
return LOG.debug(f"message type {_msg_type} "
f"is blacklisted for {self.peer}")

LOG.info(f"sending to {self.peer}: {message}")
payload = message.serialize() # json string
if self.crypto_key and message.msg_type not in [HiveMessageType.HANDSHAKE,
Expand Down Expand Up @@ -90,7 +101,7 @@ def decode(self, payload: str) -> HiveMessage:
def authorize(self, message: Message) -> bool:
""" parse the message being injected into ovos-core bus
if this client is not authorized to inject it return False"""
if message.msg_type in self.blacklist:
if message.msg_type not in self.allowed_types:
return False

# TODO check intent / skill that will trigger
Expand Down Expand Up @@ -185,7 +196,7 @@ def handle_internal_mycroft(self, message: str):
@dataclass()
class HiveMindListenerProtocol:
loop: ioloop.IOLoop
clients: dict = field(default_factory=dict)
clients = {}
internal_protocol: Optional[HiveMindListenerInternalProtocol] = None
peer: str = "master:0.0.0.0"

Expand Down Expand Up @@ -455,7 +466,10 @@ def handle_inject_mycroft_msg(self, message: Message, client: HiveMindClientConn

# ensure client specific session data is injected in query to ovos
message.context["session"] = client.sess.serialize()
message.context["destination"] = "skills" # ensure not treated as a broadcast
if message.msg_type == "speak":
message.context["destination"] = ["audio"]
elif message.context.get("destination") is None:
message.context["destination"] = "skills" # ensure not treated as a broadcast

# send client message to internal mycroft bus
LOG.info(f"Forwarding message to mycroft bus from client: {client.peer}")
Expand Down
53 changes: 52 additions & 1 deletion hivemind_core/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import click
from ovos_utils.xdg_utils import xdg_data_home
from rich.console import Console
from rich.prompt import Prompt
from rich.table import Table

from hivemind_core.database import ClientDatabase
Expand Down Expand Up @@ -51,7 +52,57 @@ def add_client(name, access_key, password, crypto_key):
print("WARNING: Encryption Key is deprecated, only use if your client does not support password")


@hmcore_cmds.command(help="remove credentials for a client (numeric unique ID)", name="delete-client")
@hmcore_cmds.command(help="allow message types sent from a client", name="allow-msg")
@click.argument("msg_type", required=True, type=str)
@click.argument("node_id", required=False, type=int)
def allow_msg(msg_type, node_id):
if not node_id:
# list clients and prompt for id using rich
table = Table(title="HiveMind Clients")
table.add_column("ID", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="magenta")
table.add_column("Allowed Msg Types", style="yellow")
_choices = []
for client in ClientDatabase():
if client["client_id"] != -1:
table.add_row(str(client["client_id"]),
client["name"],
str(client.get("allowed_types", [])))
_choices.append(str(client["client_id"]))

if not _choices:
print("No clients found!")
exit()
elif len(_choices) > 1:
console = Console()
console.print(table)
_exit = str(max(int(i) for i in _choices) + 1)
node_id = Prompt.ask(f"To which client you want to add '{msg_type}'? ({_exit}='Exit')",
choices=_choices + [_exit])
if node_id == _exit:
console.print("User exit", style="red")
exit()
else:
node_id = _choices[0]

with ClientDatabase() as db:
for client in db:
if client["client_id"] == int(node_id):
allowed_types = client.get("allowed_types", [])
if msg_type in allowed_types:
print(f"Client {client['name']} already allowed '{msg_type}'")
exit()

allowed_types.append(msg_type)
client["allowed_types"] = allowed_types
item_id = db.get_item_id(client)
db.update_item(item_id, client)
print(f"Allowed '{msg_type}' for {client['name']}")
break


@hmcore_cmds.command(help="remove credentials for a client (numeric unique ID)",
name="delete-client")
@click.argument("node_id", required=True, type=int)
def delete_client(node_id):
with ClientDatabase() as db:
Expand Down
9 changes: 5 additions & 4 deletions hivemind_core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,12 @@ def open(self):
self.close()
return

self.client.crypto_key = users.get_crypto_key(key)
pswd = users.get_password(key)
if pswd:
self.client.crypto_key = user.crypto_key
self.client.blacklist = user.blacklist.get("messages", [])
self.client.allowed_types = user.allowed_types
if user.password:
# pre-shared password to derive aes_key
self.client.pswd_handshake = PasswordHandShake(pswd)
self.client.pswd_handshake = PasswordHandShake(user.password)

self.client.node_type = HiveMindNodeType.NODE # TODO . placeholder

Expand Down