Backward compatibility fix for the Conversation class (huggingface#27176)

* Backward compatibility fix for the Conversation class

* Explain what's going on in the conditional
Rocketknight1 authored and EduardoPach committed Nov 19, 2023
1 parent 29eaa13 commit 91c56d1
Showing 2 changed files with 21 additions and 9 deletions.
src/transformers/pipelines/conversational.py (16 additions, 4 deletions)
@@ -54,6 +54,7 @@ def __init__(

         # This block deals with the legacy args - new code should just totally
         # avoid past_user_inputs and generated_responses
+        self._num_processed_user_inputs = 0
         generated_responses = deprecated_kwargs.pop("generated_responses", None)
         past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
         if generated_responses is not None and past_user_inputs is None:
@@ -114,10 +115,11 @@ def append_response(self, response: str):

     def mark_processed(self):
         """
-        This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
-        processed and unprocessed user input.
+        This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user
+        input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read
+        the messages directly when writing new code.
         """
-        pass
+        self._num_processed_user_inputs = len(self._user_messages)

     def __iter__(self):
         for message in self.messages:
@@ -163,7 +165,17 @@ def _user_messages(self):

     @property
     def past_user_inputs(self):
         # This is a legacy property for backwards compatibility. It is recommended to just directly access
-        # conversation.messages instead.
+        # conversation.messages instead. The modern class does not care about which messages are "processed"
+        # or not.
+        if not self._user_messages:
+            return []
+        # In the past, the most recent user message had to be mark_processed() before being included
+        # in past_user_messages. The class essentially had a single-message buffer, representing messages that
+        # had not yet been replied to. This is no longer the case, but we mimic the behaviour in this property
+        # for backward compatibility.
+        if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages):
+            return self._user_messages
+
+        return self._user_messages[:-1]

     @property
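
For readers skimming the diff, here is a minimal sketch (not part of the commit) of the behaviour the reworked past_user_inputs property restores, assuming Conversation, mark_processed() and append_response() work as in the code above:

from transformers import Conversation

# A fresh conversation ends with an unanswered user message. The legacy property
# now holds that message back, mimicking the old single-message buffer.
conversation = Conversation("Going to the movies tonight - any suggestions?")
print(len(conversation.past_user_inputs))  # 0

# mark_processed() sets the internal counter to the number of user messages, so
# the pending input now counts as processed and is reported.
conversation.mark_processed()
print(len(conversation.past_user_inputs))  # 1

# Appending a response also completes the exchange: the last message is no longer
# a user message, so all user inputs are reported.
conversation = Conversation("What's the last book you have read?")
conversation.append_response("I recently enjoyed a collection of short stories.")
print(len(conversation.past_user_inputs))  # 1
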
tests/pipelines/test_pipelines_conversational.py (5 additions, 5 deletions)
Expand Up @@ -136,8 +136,8 @@ def test_integration_torch_conversation(self):
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 1)
self.assertEqual(len(conversation_2.past_user_inputs), 1)
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
@@ -167,7 +167,7 @@ def test_integration_torch_conversation_truncated_history(self):
         conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
         # When
         result = conversation_agent(conversation_1, do_sample=False, max_length=36)
         # Then
@@ -375,8 +375,8 @@ def test_integration_torch_conversation_encoder_decoder(self):
         conversation_1 = Conversation("My name is Sarah and I live in London")
         conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
-        self.assertEqual(len(conversation_2.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
         # When
         result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
         # Then
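
As a rough illustration of what the updated assertions expect at runtime (a sketch, not part of the test suite; the model name is only an example, the tests above rely on the pipeline's default model):

from transformers import Conversation, pipeline

# Example checkpoint for illustration only.
conversation_agent = pipeline(task="conversational", model="microsoft/DialoGPT-small")

conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
# With this fix, the not-yet-answered input is held back, matching the legacy behaviour.
assert len(conversation_1.past_user_inputs) == 0

conversation_1 = conversation_agent(conversation_1, do_sample=False, max_length=1000)
# Once the pipeline has replied, the last message comes from the assistant, so the
# input is reported as a past user input.
assert len(conversation_1.past_user_inputs) == 1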
