From 91c56d11eb0f26b1bda62beef4048f34429f1b5f Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 31 Oct 2023 15:12:06 +0000 Subject: [PATCH] Backward compatibility fix for the Conversation class (#27176) * Backward compatibility fix for the Conversation class * Explain what's going on in the conditional --- src/transformers/pipelines/conversational.py | 20 +++++++++++++++---- .../test_pipelines_conversational.py | 10 +++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 2beaf8cc2eaf..96a16e5b0f8f 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -54,6 +54,7 @@ def __init__( # This block deals with the legacy args - new code should just totally # avoid past_user_inputs and generated_responses + self._num_processed_user_inputs = 0 generated_responses = deprecated_kwargs.pop("generated_responses", None) past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None) if generated_responses is not None and past_user_inputs is None: @@ -114,10 +115,11 @@ def append_response(self, response: str): def mark_processed(self): """ - This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between - processed and unprocessed user input. + This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user + input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read + the messages directly when writing new code. """ - pass + self._num_processed_user_inputs = len(self._user_messages) def __iter__(self): for message in self.messages: @@ -163,7 +165,17 @@ def _user_messages(self): @property def past_user_inputs(self): # This is a legacy property for backwards compatibility. It is recommended to just directly access - # conversation.messages instead. 
+ # conversation.messages instead. The modern class does not care about which messages are "processed" + # or not. + if not self._user_messages: + return [] + # In the past, the most recent user message had to be mark_processed() before being included + # in past_user_inputs. The class essentially had a single-message buffer, representing messages that + # had not yet been replied to. This is no longer the case, but we mimic the behaviour in this property + # for backward compatibility. + if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages): + return self._user_messages + return self._user_messages[:-1] @property diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py index ba3b37055fd1..6ba2d8379d2a 100644 --- a/tests/pipelines/test_pipelines_conversational.py +++ b/tests/pipelines/test_pipelines_conversational.py @@ -136,8 +136,8 @@ def test_integration_torch_conversation(self): conversation_1 = Conversation("Going to the movies tonight - any suggestions?") conversation_2 = Conversation("What's the last book you have read?") # Then - self.assertEqual(len(conversation_1.past_user_inputs), 1) - self.assertEqual(len(conversation_2.past_user_inputs), 1) + self.assertEqual(len(conversation_1.past_user_inputs), 0) + self.assertEqual(len(conversation_2.past_user_inputs), 0) # When result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000) # Then @@ -167,7 +167,7 @@ def test_integration_torch_conversation_truncated_history(self): conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device) conversation_1 = Conversation("Going to the movies tonight - any suggestions?") # Then - self.assertEqual(len(conversation_1.past_user_inputs), 1) + self.assertEqual(len(conversation_1.past_user_inputs), 0) # When result = conversation_agent(conversation_1, do_sample=False, max_length=36) # Then @@
-375,8 +375,8 @@ def test_integration_torch_conversation_encoder_decoder(self): conversation_1 = Conversation("My name is Sarah and I live in London") conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ") # Then - self.assertEqual(len(conversation_1.past_user_inputs), 1) - self.assertEqual(len(conversation_2.past_user_inputs), 1) + self.assertEqual(len(conversation_1.past_user_inputs), 0) + self.assertEqual(len(conversation_2.past_user_inputs), 0) # When result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000) # Then