Skip to content

Commit

Permalink
ocp_entities
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Jan 9, 2024
1 parent e602d07 commit 268c2a1
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 49 deletions.
146 changes: 98 additions & 48 deletions ovos_core/intent_services/ocp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
from threading import RLock
from typing import List

from ovos_bus_client.apis.ocp import OCPInterface, OCPQuery
from ovos_bus_client.message import Message
from ovos_utils import classproperty
from ovos_utils.log import LOG
from ovos_utils.messagebus import FakeBus
from ovos_utils.ocp import MediaType, PlaybackType, PlaybackMode, PlayerState, OCP_ID
from padacioso import IntentContainer
from sklearn.pipeline import FeatureUnion

import ovos_core.intent_services
from ovos_bus_client.apis.ocp import OCPInterface, OCPQuery
from ovos_bus_client.message import Message
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
from ovos_classifiers.skovos.features import ClassifierProbaVectorizer, KeywordFeaturesVectorizer
from ovos_workshop.app import OVOSAbstractApplication


class OCPFeaturizer:
ocp_keywords = KeywordFeaturesVectorizer()
ocp_keywords = KeywordFeaturesVectorizer(ignore_list=["play"])
ocp_keywords.load_entities(f"{dirname(__file__)}/models/ocp_entities_v0.csv")

def __init__(self, base_clf=None):
Expand All @@ -31,6 +32,40 @@ def __init__(self, base_clf=None):
base_clf = SklearnOVOSClassifier.from_file(clf_path)
self.clf_feats = ClassifierProbaVectorizer(base_clf)

@classproperty
def labels(cls):
"""
in V0 classifier using synth dataset
lbls = ['ad_keyword', 'album_name', 'anime_genre', 'anime_name', 'anime_streaming_service',
'artist_name', 'asmr_keyword', 'asmr_trigger', 'audio_genre', 'audiobook_narrator',
'audiobook_streaming_service', 'book_author', 'book_genre', 'book_name',
'bw_movie_name', 'cartoon_genre', 'cartoon_name', 'cartoon_streaming_service',
'comic_name', 'comic_streaming_service', 'comics_genre', 'country_name',
'documentary_genre', 'documentary_name', 'documentary_streaming_service',
'film_genre', 'film_studio', 'game_genre', 'game_name', 'gaming_console_name',
'generic_streaming_service', 'hentai_name', 'hentai_streaming_service',
'media_type_adult', 'media_type_adult_audio', 'media_type_anime', 'media_type_audio',
'media_type_audiobook', 'media_type_bts', 'media_type_bw_movie', 'media_type_cartoon',
'media_type_documentary', 'media_type_game', 'media_type_hentai', 'media_type_movie',
'media_type_music', 'media_type_news', 'media_type_podcast', 'media_type_radio',
'media_type_radio_theatre', 'media_type_short_film', 'media_type_silent_movie',
'media_type_sound', 'media_type_trailer', 'media_type_tv', 'media_type_video',
'media_type_video_episodes', 'media_type_visual_story', 'movie_actor',
'movie_director', 'movie_name', 'movie_streaming_service', 'music_genre',
'music_streaming_service', 'news_provider', 'news_streaming_service',
'play_verb_audio', 'play_verb_video', 'playback_device', 'playlist_name',
'podcast_genre', 'podcast_name', 'podcast_streaming_service', 'podcaster',
'porn_film_name', 'porn_genre', 'porn_streaming_service', 'pornstar_name',
'radio_drama_actor', 'radio_drama_genre', 'radio_drama_name', 'radio_program',
'radio_program_name', 'radio_streaming_service', 'radio_theatre_company',
'radio_theatre_streaming_service', 'record_label', 'series_name',
'short_film_name', 'shorts_streaming_service', 'silent_movie_name',
'song_name', 'sound_name', 'soundtrack_keyword', 'tv_channel', 'tv_genre',
'tv_streaming_service', 'video_genre', 'video_streaming_service', 'youtube_channel']
"""
return cls.ocp_keywords._transformer.labels

def transform(self, X):
if self.clf_feats:
vec = FeatureUnion([
Expand Down Expand Up @@ -169,20 +204,27 @@ def handle_skill_keyword_register(self, message: Message):
skill_id = message.data["skill_id"]
kw_label = message.data["label"]
media = message.data["media_type"]
samples = message.data["samples"]
langs = message.data["langs"]
samples = message.data.get("samples", [])
csv_path = message.data.get("csv")

# set bias in classifier
OCPFeaturizer.ocp_keywords.register_entity(kw_label, samples)
# NB: we need to validate labels,
# they MUST be part of the classifier training data

if kw_label in OCPFeaturizer.labels:
# set bias in classifier
if csv_path:
OCPFeaturizer.ocp_keywords.load_entities(csv_path)
if samples:
OCPFeaturizer.ocp_keywords.register_entity(kw_label, samples)
OCPFeaturizer.ocp_keywords.fit() # update

def handle_skill_keyword_deregister(self, message: Message):
skill_id = message.data["skill_id"]
kw_label = message.data["label"]
media = message.data["media_type"]
langs = message.data["langs"]

# unset bias in classifier
OCPFeaturizer.ocp_keywords.deregister_entity(kw_label)
# OCPFeaturizer.ocp_keywords.deregister_entity(kw_label)

def handle_player_state_update(self, message: Message):
"""
Expand Down Expand Up @@ -219,6 +261,7 @@ def match_high(self, utterances: List[str], lang: str, message: Message = None):
if match["name"] is None:
return None
if match["name"] == "play":
utterance = match["entities"].pop("query")
return self._process_play_query(utterance, lang, match)

if self.player_state == PlayerState.STOPPED:
Expand Down Expand Up @@ -367,7 +410,6 @@ def handle_play_intent(self, message: Message):
if e == media_type:
media_type = e
break

results = self._search(query, media_type, lang)

# tell OCP to play
Expand Down Expand Up @@ -434,6 +476,8 @@ def _do_play(self, phrase: str, results, media_type=MediaType.GENERIC):
# NLP
@staticmethod
def label2media(label: str):
if isinstance(label, MediaType):
return label
if label == "ad":
mt = MediaType.AUDIO_DESCRIPTION
elif label == "adult":
Expand Down Expand Up @@ -487,6 +531,7 @@ def label2media(label: str):
elif label == "video":
mt = MediaType.VIDEO
else:
LOG.error(f"bad label {label}")
mt = MediaType.GENERIC
return mt

Expand All @@ -506,7 +551,11 @@ def classify_media(self, query: str, lang: str):
prob = preds[label]
LOG.info(f"OVOSCommonPlay MediaType prediction: {label} confidence: {prob}")
LOG.debug(f" utterance: {query}")
return self.label2media(label), prob
if prob < self.config.get("classifier_threshold", 0.5):
LOG.info("ignoring MediaType classifier, low confidence prediction")
return MediaType.GENERIC, prob
else:
return self.label2media(label), prob

def is_ocp_query(self, query: str, lang: str):
""" determine if a playback question is being asked"""
Expand Down Expand Up @@ -541,65 +590,66 @@ def _should_resume(self, phrase: str, lang: str) -> bool:
return False

# search
def filter_results(self, results: list) -> list:
def filter_results(self, results: list, phrase: str, lang: str,
media_type: MediaType = MediaType.GENERIC) -> list:

# for debugging TODO delete
# from pprint import pprint
# pprint(results)

# ignore very low score matches
l1 = len(results)
results = [r for r in results
if r["match_confidence"] >= self.config.get("min_score", 50)]
LOG.debug(f"filtered {len(results) - l1} low confidence results")

# TODO filter based on available stream handlers
return results

def _search(self, phrase: str, media_type: MediaType, lang: str):
# filter based on MediaType (default disabled)
if self.config.get("filter_media") and media_type != MediaType.GENERIC:
l1 = len(results)
results = [r for r in results
if r["media_type"] == media_type]
LOG.debug(f"filtered {len(results) - l1} wrong MediaType results")

self.enclosure.mouth_think() # animate mk1 mouth during search
# TODO filter based on available stream handlers

# check if user said "play XXX audio only/no video"
audio_only = False
video_only = False
if self.voc_match(phrase, "audio_only", lang=lang):
audio_only = True
# dont include "audio only" in search query
phrase = self.remove_voc(phrase, "audio_only", lang=lang)

elif self.voc_match(phrase, "video_only", lang=lang):
video_only = True
# dont include "video only" in search query
phrase = self.remove_voc(phrase, "video_only", lang=lang)

# Now we place a query on the messsagebus for anyone who wants to
# attempt to service a 'play.request' message.
results = []
for r in self._execute_query(phrase, media_type=media_type):
results += r["results"]
LOG.debug(f"Got {len(results)} results")

# ignore very low score matches
results = self.filter_results(results)
LOG.debug(f"Got {len(results)} usable results")

# check if user said "play XXX audio only"
if audio_only:
LOG.info("audio only requested, forcing audio playback "
"unconditionally")
for idx, r in enumerate(results):
# force streams to be played audio only
results[idx]["playback"] = PlaybackType.AUDIO
l1 = len(results)
results = [r for r in results
if r["playback"] == PlaybackType.AUDIO]
LOG.debug(f"filtered {len(results) - l1} non-audio results")

# check if user said "play XXX video only"
elif video_only:
LOG.info("video only requested, filtering non-video results")
for idx, r in enumerate(results):
if results[idx]["media_type"] == MediaType.VIDEO:
# force streams to be played in video mode, even if
# audio playback requested
results[idx]["playback"] = PlaybackType.VIDEO

# filter audio only streams
l1 = len(results)
results = [r for r in results
if r["playback"] == PlaybackType.VIDEO]
LOG.debug(f"filtered {len(results) - l1} non-video results")

return results

def _search(self, phrase: str, media_type: MediaType, lang: str):

self.enclosure.mouth_think() # animate mk1 mouth during search

LOG.debug(f"Returning {len(results)} results")
# Now we place a query on the messsagebus for anyone who wants to
# attempt to service a 'play.request' message.
results = []
for r in self._execute_query(phrase, media_type=media_type):
results += r["results"]

LOG.debug(f"Got {len(results)} results")
results = self.filter_results(results, phrase, lang)
LOG.debug(f"Got {len(results)} usable results")
return results

def _execute_query(self, phrase: str, media_type: MediaType = MediaType.GENERIC):
Expand Down Expand Up @@ -671,8 +721,8 @@ def select_best(self, results):
# TODO: Ask user to pick between ties or do it automagically
else:
selected = best
LOG.debug(f"OVOSCommonPlay selected: {selected['skill_id']} - "
f"{selected['match_confidence']}")
LOG.info(f"OVOSCommonPlay selected: {selected['skill_id']} - {selected['match_confidence']}")
LOG.debug(str(selected))
return selected


Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ ovos-backend-client~=0.1.0
ovos-workshop<0.1.0, >=0.0.15

# provides plugins and classic machine learning framework
ovos-classifiers<0.1.0, >=0.0.0a44
ovos-classifiers<0.1.0, >=0.0.0a47

0 comments on commit 268c2a1

Please sign in to comment.