[BUG] Wrong instance passed in the inputs argument of .format_output() #780

Closed
bergr7 opened this issue Jul 12, 2024 · 4 comments · Fixed by #781
bergr7 commented Jul 12, 2024

Describe the bug
The first instance in the batch is always passed to the inputs argument of the .format_output() method of the Task class. This can create input-output mismatches, e.g. when the data in inputs is used to create metadata for the generated output.

Note that the correct instance is used for formatting the input and generating the output; only .format_output() receives the wrong one.

To work around this, I had to set input_batch_size=1.
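
For reference, a minimal sketch of that workaround (using the same task and llm defined in the reproduction below): with input_batch_size=1 every batch holds a single instance, so the instance passed to .format_output() is necessarily the right one, at the cost of throughput.

task = MyCustomTask(
    name="run_my_custom_task",
    llm=llm,
    input_batch_size=1,  # workaround: a batch of one cannot mismatch
)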

To Reproduce

I've created a simplified version of my code to reproduce the bug.

Code to reproduce

from distilabel.steps.tasks import Task
from distilabel.steps.tasks.typing import ChatType
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, KeepColumns

from typing import List, Dict, Any

class MyCustomTask(Task):
    @property
    def inputs(self) -> List[str]:
        return ["input", "metadata"]

    @property
    def outputs(self) -> List[str]:
        return ["output", "metadata"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
        return [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": input["input"],
            },
        ]

    def format_output(self, output: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # ! inputs always store the first instance in the batch
        # ! If inputs is used for creating metadata or similar
        # ! it creates an input - output mismatch
        metadata = {
            "parent_record_id": inputs["metadata"]["record_id"],
            "parent_record_type": inputs["metadata"]["record_type"],
        }

        return {"output": output, "metadata": metadata}


llm = OpenAILLM(model="gpt-4o")

# some dummy data
data = [
    {"input": "Hello, how are you?", "metadata": {"record_id": 1, "record_type": "user"}},
    {"input": "I'm doing well, thanks!", "metadata": {"record_id": 2, "record_type": "assistant"}},
    {"input": "How can I help you today?", "metadata": {"record_id": 3, "record_type": "user"}},
    {"input": "I'd like to book a flight.", "metadata": {"record_id": 4, "record_type": "assistant"}},
    {"input": "Can you please provide me with the flight details?", "metadata": {"record_id": 5, "record_type": "user"}},
    {"input": "Sure, I'll book the flight for you.", "metadata": {"record_id": 6, "record_type": "assistant"}},
    {"input": "Thank you for your booking!", "metadata": {"record_id": 7, "record_type": "user"}},
]


with Pipeline("my_pipeline") as pipeline:
    load_dataset = LoadDataFromDicts(
        name="load_dataset",
        data=data,
    )

    task = MyCustomTask(
        name="run_my_custom_task",
        llm=llm,
        input_batch_size=2,
    )

    output_cols = KeepColumns(
        name="output_cols",
        columns=["output", "metadata"],
    )

    load_dataset >> task >> output_cols


distiset = pipeline.run(
    parameters={
        task.name: {
            "llm": {"generation_kwargs": {"max_new_tokens": 10}}
        }
    },
    use_cache=False,
)
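
After the run, the mismatch is easy to spot by printing each row's metadata next to its output (a quick sanity check, assuming the usual Distiset layout where the single leaf step's data ends up under distiset["default"]["train"]):

# Each row's parent_record_id should identify the input that produced
# its output; with the bug and input_batch_size=2, both rows of each
# batch report the parent_record_id of the batch's first instance.
for row in distiset["default"]["train"]:
    print(row["metadata"]["parent_record_id"], "->", row["output"])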

Expected behaviour
inputs should contain the instance that was used to generate the output, instead of always the first instance in the batch.
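
In other words, the output-formatting loop should pair each generation with its own input row. A minimal sketch of the expected pairing (illustrative only; batch_inputs and batch_outputs are hypothetical names, not distilabel's actual internals):

# Pair each generation with the input row that produced it,
# rather than reusing batch_inputs[0] for the whole batch.
formatted = [
    task.format_output(output, input_row)
    for input_row, output in zip(batch_inputs, batch_outputs)
]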

Desktop (please complete the following information):

  • Package version: 1.2.1
  • Python version: 3.10.13
gabrielmbmb self-assigned this Jul 12, 2024

gabrielmbmb (Member) commented:
Hi @bergr7, thanks for reporting! I'll work on fixing this.

bergr7 (Author) commented Jul 12, 2024

> Hi @bergr7, thanks for reporting! I'll work on fixing this.

Many thanks @gabrielmbmb !! :)

gabrielmbmb linked a pull request (#781) on Jul 12, 2024 that will close this issue
gabrielmbmb (Member) commented:

Hi again @bergr7! We just released a new version, 1.2.2, with the bug fixed. Thanks again for reporting!

bergr7 (Author) commented Jul 15, 2024

> Hi again @bergr7! We just released a new version, 1.2.2, with the bug fixed. Thanks again for reporting!

Lightning fast! Thanks. I can confirm it's fixed and I've benefited from the fix already!!
