Backward compatibility fix for the Conversation class #27176

Merged (2 commits, Oct 31, 2023)
Changes from 1 commit
16 changes: 12 additions & 4 deletions src/transformers/pipelines/conversational.py
@@ -54,6 +54,7 @@ def __init__(
 
         # This block deals with the legacy args - new code should just totally
         # avoid past_user_inputs and generated_responses
+        self._num_processed_user_inputs = 0
         generated_responses = deprecated_kwargs.pop("generated_responses", None)
         past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
         if generated_responses is not None and past_user_inputs is None:
@@ -114,10 +115,11 @@ def append_response(self, response: str):
 
     def mark_processed(self):
         """
-        This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
-        processed and unprocessed user input.
+        This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user
+        input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read
+        the messages directly when writing new code.
         """
-        pass
+        self._num_processed_user_inputs = len(self._user_messages)
 
     def __iter__(self):
         for message in self.messages:
@@ -163,7 +165,13 @@ def _user_messages(self):
     @property
     def past_user_inputs(self):
         # This is a legacy property for backwards compatibility. It is recommended to just directly access
-        # conversation.messages instead.
+        # conversation.messages instead. The modern class does not care about which messages are "processed"
+        # or not.
+        if not self._user_messages:
+            return []
+        if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages):
+            return self._user_messages
+
         return self._user_messages[:-1]
Comment on lines +176 to 179

Collaborator: Could we have a few comments explaining these conditions and their output values?

Collaborator: I mean it's hard to understand why we return self._user_messages when self.messages[-1]["role"] != "user".

Member (author): Added a comment!


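To make the conditions under discussion concrete, here is a minimal sketch of the legacy behaviour this diff restores. It is an illustration only, not part of the PR, and assumes the Conversation constructor stores its prompt as a single "user" message, as the hunks above indicate:

    from transformers import Conversation

    conversation = Conversation("Going to the movies tonight - any suggestions?")

    # The last message is a not-yet-processed user turn: _num_processed_user_inputs
    # is still 0, so the legacy property hides it, as the pre-refactor class did.
    assert conversation.past_user_inputs == []

    # mark_processed() sets the counter to len(self._user_messages), so the second
    # condition in past_user_inputs now holds and every user message is returned.
    conversation.mark_processed()
    assert conversation.past_user_inputs == ["Going to the movies tonight - any suggestions?"]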
10 changes: 5 additions & 5 deletions tests/pipelines/test_pipelines_conversational.py
@@ -140,8 +140,8 @@ def test_integration_torch_conversation(self):
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         conversation_2 = Conversation("What's the last book you have read?")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
-        self.assertEqual(len(conversation_2.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
         # When
         result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
         # Then
@@ -171,7 +171,7 @@ def test_integration_torch_conversation_truncated_history(self):
         conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
         # When
         result = conversation_agent(conversation_1, do_sample=False, max_length=36)
         # Then
@@ -379,8 +379,8 @@ def test_integration_torch_conversation_encoder_decoder(self):
         conversation_1 = Conversation("My name is Sarah and I live in London")
         conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
-        self.assertEqual(len(conversation_2.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
         # When
         result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
         # Then
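The updated expectations follow directly from the new past_user_inputs logic: a freshly constructed conversation ends in an unprocessed user turn, so the property starts empty rather than containing the prompt. A minimal sketch, assuming the pipeline appends the model reply via append_response (the method shown in the second hunk of conversational.py above):

    from transformers import Conversation

    conversation = Conversation("Going to the movies tonight - any suggestions?")
    assert len(conversation.past_user_inputs) == 0  # the only user turn is unprocessed

    # After a response is appended, the last message has role "assistant", so the
    # property no longer drops the final user message.
    conversation.append_response("What about The Big Lebowski?")
    assert len(conversation.past_user_inputs) == 1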