diff --git a/ovos_core/intent_services/ocp_service.py b/ovos_core/intent_services/ocp_service.py index 73d3616dd0b..5e0e22799f4 100644 --- a/ovos_core/intent_services/ocp_service.py +++ b/ovos_core/intent_services/ocp_service.py @@ -4,8 +4,7 @@ from threading import RLock from typing import List -from ovos_bus_client.apis.ocp import OCPInterface, OCPQuery -from ovos_bus_client.message import Message +from ovos_utils import classproperty from ovos_utils.log import LOG from ovos_utils.messagebus import FakeBus from ovos_utils.ocp import MediaType, PlaybackType, PlaybackMode, PlayerState, OCP_ID @@ -13,13 +12,15 @@ from sklearn.pipeline import FeatureUnion import ovos_core.intent_services +from ovos_bus_client.apis.ocp import OCPInterface, OCPQuery +from ovos_bus_client.message import Message from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier from ovos_classifiers.skovos.features import ClassifierProbaVectorizer, KeywordFeaturesVectorizer from ovos_workshop.app import OVOSAbstractApplication class OCPFeaturizer: - ocp_keywords = KeywordFeaturesVectorizer() + ocp_keywords = KeywordFeaturesVectorizer(ignore_list=["play"]) ocp_keywords.load_entities(f"{dirname(__file__)}/models/ocp_entities_v0.csv") def __init__(self, base_clf=None): @@ -31,6 +32,40 @@ def __init__(self, base_clf=None): base_clf = SklearnOVOSClassifier.from_file(clf_path) self.clf_feats = ClassifierProbaVectorizer(base_clf) + @classproperty + def labels(cls): + """ + in V0 classifier using synth dataset + + lbls = ['ad_keyword', 'album_name', 'anime_genre', 'anime_name', 'anime_streaming_service', + 'artist_name', 'asmr_keyword', 'asmr_trigger', 'audio_genre', 'audiobook_narrator', + 'audiobook_streaming_service', 'book_author', 'book_genre', 'book_name', + 'bw_movie_name', 'cartoon_genre', 'cartoon_name', 'cartoon_streaming_service', + 'comic_name', 'comic_streaming_service', 'comics_genre', 'country_name', + 'documentary_genre', 'documentary_name', 'documentary_streaming_service', + 'film_genre', 'film_studio', 'game_genre', 'game_name', 'gaming_console_name', + 'generic_streaming_service', 'hentai_name', 'hentai_streaming_service', + 'media_type_adult', 'media_type_adult_audio', 'media_type_anime', 'media_type_audio', + 'media_type_audiobook', 'media_type_bts', 'media_type_bw_movie', 'media_type_cartoon', + 'media_type_documentary', 'media_type_game', 'media_type_hentai', 'media_type_movie', + 'media_type_music', 'media_type_news', 'media_type_podcast', 'media_type_radio', + 'media_type_radio_theatre', 'media_type_short_film', 'media_type_silent_movie', + 'media_type_sound', 'media_type_trailer', 'media_type_tv', 'media_type_video', + 'media_type_video_episodes', 'media_type_visual_story', 'movie_actor', + 'movie_director', 'movie_name', 'movie_streaming_service', 'music_genre', + 'music_streaming_service', 'news_provider', 'news_streaming_service', + 'play_verb_audio', 'play_verb_video', 'playback_device', 'playlist_name', + 'podcast_genre', 'podcast_name', 'podcast_streaming_service', 'podcaster', + 'porn_film_name', 'porn_genre', 'porn_streaming_service', 'pornstar_name', + 'radio_drama_actor', 'radio_drama_genre', 'radio_drama_name', 'radio_program', + 'radio_program_name', 'radio_streaming_service', 'radio_theatre_company', + 'radio_theatre_streaming_service', 'record_label', 'series_name', + 'short_film_name', 'shorts_streaming_service', 'silent_movie_name', + 'song_name', 'sound_name', 'soundtrack_keyword', 'tv_channel', 'tv_genre', + 'tv_streaming_service', 'video_genre', 'video_streaming_service', 'youtube_channel'] + """ + return cls.ocp_keywords._transformer.labels + def transform(self, X): if self.clf_feats: vec = FeatureUnion([ @@ -169,20 +204,27 @@ def handle_skill_keyword_register(self, message: Message): skill_id = message.data["skill_id"] kw_label = message.data["label"] media = message.data["media_type"] - samples = message.data["samples"] - langs = message.data["langs"] + samples = message.data.get("samples", []) + csv_path = message.data.get("csv") - # set bias in classifier - OCPFeaturizer.ocp_keywords.register_entity(kw_label, samples) + # NB: we need to validate labels, + # they MUST be part of the classifier training data + + if kw_label in OCPFeaturizer.labels: + # set bias in classifier + if csv_path: + OCPFeaturizer.ocp_keywords.load_entities(csv_path) + if samples: + OCPFeaturizer.ocp_keywords.register_entity(kw_label, samples) + OCPFeaturizer.ocp_keywords.fit() # update def handle_skill_keyword_deregister(self, message: Message): skill_id = message.data["skill_id"] kw_label = message.data["label"] media = message.data["media_type"] - langs = message.data["langs"] # unset bias in classifier - OCPFeaturizer.ocp_keywords.deregister_entity(kw_label) + # OCPFeaturizer.ocp_keywords.deregister_entity(kw_label) def handle_player_state_update(self, message: Message): """ @@ -219,6 +261,7 @@ def match_high(self, utterances: List[str], lang: str, message: Message = None): if match["name"] is None: return None if match["name"] == "play": + utterance = match["entities"].pop("query") return self._process_play_query(utterance, lang, match) if self.player_state == PlayerState.STOPPED: @@ -367,7 +410,6 @@ def handle_play_intent(self, message: Message): if e == media_type: media_type = e break - results = self._search(query, media_type, lang) # tell OCP to play @@ -434,6 +476,8 @@ def _do_play(self, phrase: str, results, media_type=MediaType.GENERIC): # NLP @staticmethod def label2media(label: str): + if isinstance(label, MediaType): + return label if label == "ad": mt = MediaType.AUDIO_DESCRIPTION elif label == "adult": @@ -487,6 +531,7 @@ def label2media(label: str): elif label == "video": mt = MediaType.VIDEO else: + LOG.error(f"bad label {label}") mt = MediaType.GENERIC return mt @@ -506,7 +551,11 @@ def classify_media(self, query: str, lang: str): prob = preds[label] LOG.info(f"OVOSCommonPlay MediaType prediction: {label} confidence: {prob}") LOG.debug(f" utterance: {query}") - return self.label2media(label), prob + if prob < self.config.get("classifier_threshold", 0.5): + LOG.info("ignoring MediaType classifier, low confidence prediction") + return MediaType.GENERIC, prob + else: + return self.label2media(label), prob def is_ocp_query(self, query: str, lang: str): """ determine if a playback question is being asked""" @@ -541,65 +590,66 @@ def _should_resume(self, phrase: str, lang: str) -> bool: return False # search - def filter_results(self, results: list) -> list: + def filter_results(self, results: list, phrase: str, lang: str, + media_type: MediaType = MediaType.GENERIC) -> list: + + # for debugging TODO delete + # from pprint import pprint + # pprint(results) # ignore very low score matches + l1 = len(results) results = [r for r in results if r["match_confidence"] >= self.config.get("min_score", 50)] + LOG.debug(f"filtered {len(results) - l1} low confidence results") - # TODO filter based on available stream handlers - return results - - def _search(self, phrase: str, media_type: MediaType, lang: str): + # filter based on MediaType (default disabled) + if self.config.get("filter_media") and media_type != MediaType.GENERIC: + l1 = len(results) + results = [r for r in results + if r["media_type"] == media_type] + LOG.debug(f"filtered {len(results) - l1} wrong MediaType results") - self.enclosure.mouth_think() # animate mk1 mouth during search + # TODO filter based on available stream handlers # check if user said "play XXX audio only/no video" audio_only = False video_only = False if self.voc_match(phrase, "audio_only", lang=lang): audio_only = True - # dont include "audio only" in search query - phrase = self.remove_voc(phrase, "audio_only", lang=lang) elif self.voc_match(phrase, "video_only", lang=lang): video_only = True - # dont include "video only" in search query - phrase = self.remove_voc(phrase, "video_only", lang=lang) - - # Now we place a query on the messsagebus for anyone who wants to - # attempt to service a 'play.request' message. - results = [] - for r in self._execute_query(phrase, media_type=media_type): - results += r["results"] - LOG.debug(f"Got {len(results)} results") - - # ignore very low score matches - results = self.filter_results(results) - LOG.debug(f"Got {len(results)} usable results") # check if user said "play XXX audio only" if audio_only: - LOG.info("audio only requested, forcing audio playback " - "unconditionally") - for idx, r in enumerate(results): - # force streams to be played audio only - results[idx]["playback"] = PlaybackType.AUDIO + l1 = len(results) + results = [r for r in results + if r["playback"] == PlaybackType.AUDIO] + LOG.debug(f"filtered {len(results) - l1} non-audio results") # check if user said "play XXX video only" elif video_only: - LOG.info("video only requested, filtering non-video results") - for idx, r in enumerate(results): - if results[idx]["media_type"] == MediaType.VIDEO: - # force streams to be played in video mode, even if - # audio playback requested - results[idx]["playback"] = PlaybackType.VIDEO - - # filter audio only streams + l1 = len(results) results = [r for r in results if r["playback"] == PlaybackType.VIDEO] + LOG.debug(f"filtered {len(results) - l1} non-video results") + + return results + + def _search(self, phrase: str, media_type: MediaType, lang: str): + + self.enclosure.mouth_think() # animate mk1 mouth during search - LOG.debug(f"Returning {len(results)} results") + # Now we place a query on the messsagebus for anyone who wants to + # attempt to service a 'play.request' message. + results = [] + for r in self._execute_query(phrase, media_type=media_type): + results += r["results"] + + LOG.debug(f"Got {len(results)} results") + results = self.filter_results(results, phrase, lang) + LOG.debug(f"Got {len(results)} usable results") return results def _execute_query(self, phrase: str, media_type: MediaType = MediaType.GENERIC): @@ -671,8 +721,8 @@ def select_best(self, results): # TODO: Ask user to pick between ties or do it automagically else: selected = best - LOG.debug(f"OVOSCommonPlay selected: {selected['skill_id']} - " - f"{selected['match_confidence']}") + LOG.info(f"OVOSCommonPlay selected: {selected['skill_id']} - {selected['match_confidence']}") + LOG.debug(str(selected)) return selected diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d9ee832315a..dc6032f80ec 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -15,4 +15,4 @@ ovos-backend-client~=0.1.0 ovos-workshop<0.1.0, >=0.0.15 # provides plugins and classic machine learning framework -ovos-classifiers<0.1.0, >=0.0.0a44 +ovos-classifiers<0.1.0, >=0.0.0a47