diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index cec453386..66280afcf 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -1,7 +1,7 @@ import logging from chatterbot.storage import StorageAdapter from chatterbot.logic import LogicAdapter -from chatterbot.search import IndexedTextSearch +from chatterbot.search import TextSearch, IndexedTextSearch from chatterbot import utils @@ -28,9 +28,11 @@ def __init__(self, name, **kwargs): self.storage = utils.initialize_class(storage_adapter, **kwargs) primary_search_algorithm = IndexedTextSearch(self, **kwargs) + text_search_algorithm = TextSearch(self, **kwargs) self.search_algorithms = { - primary_search_algorithm.name: primary_search_algorithm + primary_search_algorithm.name: primary_search_algorithm, + text_search_algorithm.name: text_search_algorithm } for adapter in logic_adapters: diff --git a/chatterbot/comparisons.py b/chatterbot/comparisons.py index 978f76354..741315677 100644 --- a/chatterbot/comparisons.py +++ b/chatterbot/comparisons.py @@ -3,7 +3,6 @@ designed to compare one statement to another. 
""" from difflib import SequenceMatcher -import spacy class Comparator: @@ -64,6 +63,7 @@ class SpacySimilarity(Comparator): def __init__(self, language): super().__init__(language) + import spacy self.nlp = spacy.load(self.language.ISO_639_1) diff --git a/chatterbot/logic/best_match.py b/chatterbot/logic/best_match.py index ebf04f075..0b56881e8 100644 --- a/chatterbot/logic/best_match.py +++ b/chatterbot/logic/best_match.py @@ -30,10 +30,10 @@ def process(self, input_statement, additional_response_selection_parameters=None # Search for the closest match to the input statement for result in search_results: + closest_match = result # Stop searching if a match that is close enough is found if result.confidence >= self.maximum_similarity_threshold: - closest_match = result break self.chatbot.logger.info('Using "{}" as a close match to "{}" with a confidence of {}'.format( diff --git a/chatterbot/search.py b/chatterbot/search.py index f3af44b7c..15b6990bb 100644 --- a/chatterbot/search.py +++ b/chatterbot/search.py @@ -1,10 +1,6 @@ -from chatterbot.conversation import Statement - - class IndexedTextSearch: """ - :param statement_comparison_function: The dot-notated import path - to a statement comparison function. + :param statement_comparison_function: A comparison class. Defaults to ``LevenshteinDistance``. 
:param search_page_size: @@ -69,8 +65,81 @@ def search(self, input_statement, **additional_parameters): statement_list = self.chatbot.storage.filter(**search_parameters) - closest_match = Statement(text='') - closest_match.confidence = 0 + best_confidence_so_far = 0 + + self.chatbot.logger.info('Processing search results') + + # Find the closest matching known statement + for statement in statement_list: + confidence = self.compare_statements(input_statement, statement) + + if confidence > best_confidence_so_far: + best_confidence_so_far = confidence + statement.confidence = confidence + + self.chatbot.logger.info('Similar text found: {} {}'.format( + statement.text, confidence + )) + + yield statement + + +class TextSearch: + """ + :param statement_comparison_function: A comparison class. + Defaults to ``LevenshteinDistance``. + + :param search_page_size: + The maximum number of records to load into memory at a time when searching. + Defaults to 1000 + """ + + name = 'text_search' + + def __init__(self, chatbot, **kwargs): + from chatterbot.comparisons import LevenshteinDistance + + self.chatbot = chatbot + + statement_comparison_function = kwargs.get( + 'statement_comparison_function', + LevenshteinDistance + ) + + self.compare_statements = statement_comparison_function( + language=self.chatbot.storage.tagger.language + ) + + self.search_page_size = kwargs.get( + 'search_page_size', 1000 + ) + + def search(self, input_statement, **additional_parameters): + """ + Search for close matches to the input. Confidence scores for + subsequent results will be in order of increasing value. + + :param input_statement: A statement. + :type input_statement: chatterbot.conversation.Statement + + :param **additional_parameters: Additional parameters to be passed + to the ``filter`` method of the storage adapter when searching. + + :rtype: Generator yielding one closest matching statement at a time. 
+ """ + self.chatbot.logger.info('Beginning search for close text match') + + search_parameters = { + 'persona_not_startswith': 'bot:', + 'page_size': self.search_page_size + } + + if additional_parameters: + search_parameters.update(additional_parameters) + + statement_list = self.chatbot.storage.filter(**search_parameters) + + best_confidence_so_far = 0 self.chatbot.logger.info('Processing search results') @@ -78,12 +147,12 @@ def search(self, input_statement, **additional_parameters): for statement in statement_list: confidence = self.compare_statements(input_statement, statement) - if confidence > closest_match.confidence: + if confidence > best_confidence_so_far: + best_confidence_so_far = confidence statement.confidence = confidence - closest_match = statement self.chatbot.logger.info('Similar text found: {} {}'.format( - closest_match.text, confidence + statement.text, confidence )) - yield closest_match + yield statement diff --git a/tests/logic/test_best_match.py b/tests/logic/test_best_match.py index b2e5f9bdc..8335890e3 100644 --- a/tests/logic/test_best_match.py +++ b/tests/logic/test_best_match.py @@ -134,3 +134,33 @@ def test_low_confidence_options_list(self): self.assertEqual(match.confidence, 0) self.assertEqual(match.text, 'No') + + def test_text_search_algorithm(self): + """ + Test that a close match is found when the text_search algorithm is used. + """ + self.adapter = BestMatch( + self.chatbot, + search_algorithm_name='text_search' + ) + + self.chatbot.storage.create( + text='I am hungry.' + ) + self.chatbot.storage.create( + text='Okay, what would you like to eat?', + in_response_to='I am hungry.' + ) + self.chatbot.storage.create( + text='Can you help me?' + ) + self.chatbot.storage.create( + text='Sure, what seems to be the problem?', + in_response_to='Can you help me?' 
+ ) + + statement = Statement(text='Could you help me?') + match = self.adapter.process(statement) + + self.assertEqual(match.confidence, 0.82) + self.assertEqual(match.text, 'Sure, what seems to be the problem?') diff --git a/tests/test_search.py b/tests/test_search.py index 08f77564d..3f9793ae0 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,6 +1,6 @@ from tests.base_case import ChatBotTestCase from chatterbot.conversation import Statement -from chatterbot.search import IndexedTextSearch +from chatterbot.search import TextSearch, IndexedTextSearch from chatterbot import comparisons @@ -50,7 +50,7 @@ def test_search_additional_parameters(self): self.assertEqual(results[0].conversation, 'test_1') -class SearchComparisonFunctionSpacySimilarityTests(ChatBotTestCase): +class IndexedTextSearchComparisonFunctionSpacySimilarityTests(ChatBotTestCase): """ Test that the search algorithm works correctly with the spacy similarity comparison function. @@ -97,7 +97,7 @@ def test_different_punctuation(self): self.assertEqual(results[-1].text, 'Are you good?') -class SearchComparisonFunctionLevenshteinDistanceComparisonTests(ChatBotTestCase): +class IndexedTextSearchComparisonFunctionLevenshteinDistanceComparisonTests(ChatBotTestCase): """ Test that the search algorithm works correctly with the Levenshtein distance comparison function. @@ -166,3 +166,121 @@ def test_confidence_no_match(self): results = list(self.search_algorithm.search(statement)) self.assertIsLength(results, 0) + + +class TextSearchComparisonFunctionSpacySimilarityTests(ChatBotTestCase): + """ + Test that the search algorithm works correctly with the + spacy similarity comparison function. 
+ """ + + def setUp(self): + super().setUp() + self.search_algorithm = TextSearch( + self.chatbot, + statement_comparison_function=comparisons.SpacySimilarity + ) + + def test_get_closest_statement(self): + """ + Note, the content of the in_response_to field for each of the + test statements is only required because the logic adapter will + filter out any statements that are not in response to a known statement. + """ + self.chatbot.storage.create_many([ + Statement(text='This is a lovely bog.', in_response_to='This is a lovely bog.'), + Statement(text='This is a beautiful swamp.', in_response_to='This is a beautiful swamp.'), + Statement(text='It smells like a swamp.', in_response_to='It smells like a swamp.') + ]) + + statement = Statement(text='This is a lovely swamp.') + results = list(self.search_algorithm.search(statement)) + + self.assertIsLength(results, 2) + self.assertEqual(results[-1].text, 'This is a beautiful swamp.') + self.assertGreater(results[-1].confidence, 0) + + def test_different_punctuation(self): + self.chatbot.storage.create_many([ + Statement(text='Who are you?'), + Statement(text='Are you good?'), + Statement(text='You are good') + ]) + + statement = Statement(text='Are you good') + results = list(self.search_algorithm.search(statement)) + + self.assertEqual(len(results), 2) + # Note: the last statement in the list always has the highest confidence + self.assertEqual(results[-1].text, 'Are you good?') + + +class TextSearchComparisonFunctionLevenshteinDistanceComparisonTests(ChatBotTestCase): + """ + Test that the search algorithm works correctly with the + Levenshtein distance comparison function. 
+ """ + + def setUp(self): + super().setUp() + self.search_algorithm = TextSearch( + self.chatbot, + statement_comparison_function=comparisons.LevenshteinDistance + ) + + def test_get_closest_statement(self): + """ + Note, the content of the in_response_to field for each of the + test statements is only required because the search process will + filter out any statements that are not in response to something. + """ + self.chatbot.storage.create_many([ + Statement(text='What is the meaning of life?', in_response_to='...'), + Statement(text='I am Iron Man.', in_response_to='...'), + Statement(text='What... is your quest?', in_response_to='...'), + Statement(text='Yuck, black licorice jelly beans.', in_response_to='...'), + Statement(text='I hear you are going on a quest?', in_response_to='...'), + ]) + + statement = Statement(text='What is your quest?') + + results = list(self.search_algorithm.search(statement)) + + self.assertEqual(len(results), 2) + self.assertEqual(results[-1].text, 'What... 
is your quest?', msg=results[-1].confidence) + + def test_confidence_exact_match(self): + self.chatbot.storage.create(text='What is your quest?', in_response_to='What is your quest?') + + statement = Statement(text='What is your quest?') + results = list(self.search_algorithm.search(statement)) + + self.assertIsLength(results, 1) + self.assertEqual(results[0].confidence, 1) + + def test_confidence_half_match(self): + from unittest.mock import MagicMock + + # Assume that the storage adapter returns a partial match + self.chatbot.storage.filter = MagicMock(return_value=[ + Statement(text='xxyy') + ]) + + statement = Statement(text='wwxx') + results = list(self.search_algorithm.search(statement)) + + self.assertIsLength(results, 1) + self.assertEqual(results[0].confidence, 0.5) + + def test_confidence_no_match(self): + from unittest.mock import MagicMock + + # Assume that the storage adapter returns a statement that does not match the input at all + self.search_algorithm.chatbot.storage.filter = MagicMock(return_value=[ + Statement(text='xxx', in_response_to='xxx') + ]) + + statement = Statement(text='yyy') + results = list(self.search_algorithm.search(statement)) + + self.assertIsLength(results, 0)