Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hard-negative flag to include similar challenging negatives on triplets #856

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions src/distilabel/steps/tasks/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,22 @@
" section: `## Positive`."
)

# Phrases describing how the generated negative must relate to the anchor.
# The key is chosen in `GenerateSentencePair.format_input` based on the
# `hard_negative` flag and interpolated into the `{negative_style}` slot of
# `POSITIVE_NEGATIVE_SYSTEM_PROMPT`:
# - "negative": a sentence that may share words but is unrelated to the anchor.
# - "hard-negative": a deceptively similar sentence (shared keywords and
#   structure) meant to be hard for an embedding model to distinguish.
NEGATIVE_STYLE: Dict[str, str] = {
    "negative": "can use similar words but must not be related to the anchor sentence",
    "hard-negative": (
        "is a 'hard negative' that meets the following criteria:\n"
        "- Uses similar keywords or phrases as the anchor sentence\n"
        "- Has a similar grammatical structure or syntax\n"
        "- Is not related to the anchor sentence, but could be mistaken for it\n"
        "Try to create a negative sentence that would be challenging for a model to distinguish "
        "from the positive sentence"
    ),
}

# System prompt used when `triplet=True` (both a positive and a negative are
# requested). Placeholders filled in `GenerateSentencePair.format_input`:
#   {context}          -> `CONTEXT_INTRO` when a context is given, else "".
#   {action_sentence}  -> entry of `GENERATION_ACTION_SENTENCES` for the action.
#   {negative_style}   -> entry of `NEGATIVE_STYLE` ("negative"/"hard-negative").
# NOTE(review): the scraped diff fused the removed and added lines here,
# duplicating the sentence about the negative; this is the merged version,
# which keeps only the `{negative_style}` formulation.
POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
    "Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
    " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
    " sentence {negative_style}. You must output only two new sections: `## Positive` and `## Negative`."
)

CONTEXT_INTRO: Final[str] = " Take into account the context given."
Expand All @@ -63,23 +74,28 @@ class GenerateSentencePair(Task):

`GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
a positive sentence related to the anchor and optionally a negative sentence unrelated
to the anchor or similar to it. Optionally, you can give a context to guide the LLM
towards more specific behavior. This task is useful to generate training datasets for
training embeddings models.

Attributes:
triplet: a flag to indicate if the task should generate a triplet of sentences
(anchor, positive, negative). Defaults to `False`.
action: the action to perform to generate the positive sentence.
context: the context to use for the generation. Can be helpful to guide the LLM
towards more specific context. Not used by default.
hard_negative: A flag to indicate if the negative should be a hard-negative or not.
Hard negatives make it hard for the model to distinguish against the positive,
with a higher degree of semantic similarity.

Input columns:
- anchor (`str`): The anchor sentence to generate the positive and negative sentences.

Output columns:
- positive (`str`): The positive sentence related to the `anchor`.
- negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`,
or more similar to the positive to make it more challenging for a model to distinguish
in case `hard_negative=True`.
- model_name (`str`): The name of the model that was used to generate the sentences.

Categories:
Expand All @@ -97,8 +113,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="paraphrase",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -118,8 +134,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="semantically-similar",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -139,8 +155,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="query",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -160,8 +176,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="answer",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -182,8 +198,31 @@ class GenerateSentencePair(Task):
action="query",
context="Argilla is an open-source data curation platform for LLMs.",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)

generate_sentence_pair.load()

result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```

Generating Hard-negatives (**applies to every action**):

```python
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.llms import InferenceEndpointsLLM

generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
action="query",
context="Argilla is an open-source data curation platform for LLMs.",
hard_negative=True,
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -192,10 +231,12 @@ class GenerateSentencePair(Task):

result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```

"""

triplet: bool = False
action: GenerationAction
hard_negative: bool = False
context: str = ""

def load(self) -> None:
Expand Down Expand Up @@ -229,12 +270,19 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
A list of dictionaries containing the system and user interactions.
"""
action_sentence = GENERATION_ACTION_SENTENCES[self.action]

format_system_prompt = {
"action_sentence": action_sentence,
"context": CONTEXT_INTRO if self.context else "",
}
if self.triplet:
format_system_prompt["negative_style"] = NEGATIVE_STYLE[
"hard-negative" if self.hard_negative else "negative"
]

system_prompt = (
POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
).format(
action_sentence=action_sentence,
context=CONTEXT_INTRO if self.context else "",
)
).format(**format_system_prompt)

return [
{"role": "system", "content": system_prompt},
Expand Down
Loading
Loading