Commit

Fix passing input to format_output function
gabrielmbmb committed Jul 12, 2024
1 parent fe615d6 commit 5277cd4
Showing 2 changed files with 127 additions and 16 deletions.
12 changes: 6 additions & 6 deletions src/distilabel/steps/tasks/base.py
@@ -91,21 +91,20 @@ def format_output(
def _format_outputs(
self,
outputs: "GenerateOutput",
inputs: Union[List[Dict[str, Any]], None] = None,
input: Union[Dict[str, Any], None] = None,
) -> List[Dict[str, Any]]:
"""Formats the outputs of the task using the `format_output` method. If the output
is `None` (i.e. the LLM failed to generate a response), then the outputs will be
set to `None` as well.
Args:
outputs: The outputs of the LLM.
inputs: The inputs used to generate the outputs.
outputs: The outputs (`n` generations) for the provided `input`.
input: The input used to generate the output.
Returns:
A list containing a dictionary with the outputs of the task for each input.
"""
if inputs is None:
inputs = [None] # type: ignore
inputs = [None] if input is None else [input]

formatted_outputs = []
for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
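For context, after this change `_format_outputs` receives a single input dictionary rather than the whole batch, and pairs that one input with each of the `n` generations produced for it. A minimal, self-contained sketch of that broadcasting logic (the `format_output` here is a placeholder for the task-specific method implemented by subclasses, not the real one):

```python
from typing import Any, Dict, List, Union


def format_output(
    output: Union[str, None], input: Union[Dict[str, Any], None]
) -> Dict[str, Any]:
    # Placeholder for the task-specific formatting that subclasses implement.
    return {"output": output}


def _format_outputs(
    outputs: List[Union[str, None]],
    input: Union[Dict[str, Any], None] = None,
) -> List[Dict[str, Any]]:
    # Wrap the single input so it can be paired with every generation below.
    inputs = [None] if input is None else [input]
    formatted = []
    for output, inp in zip(outputs, inputs * len(outputs)):
        formatted.append(format_output(output, inp))
    return formatted


# All three generations for one input are formatted against that same input.
print(_format_outputs(["gen_0", "gen_1", "gen_2"], {"instruction": "hi"}))
```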
@@ -195,6 +194,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore

formatted_inputs = self._format_inputs(inputs)

# `outputs` is a list containing a list of generations per input
outputs = self.llm.generate(
inputs=formatted_inputs,
num_generations=self.num_generations, # type: ignore
@@ -203,7 +203,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore

task_outputs = []
for input, input_outputs in zip(inputs, outputs):
formatted_outputs = self._format_outputs(input_outputs, inputs)
formatted_outputs = self._format_outputs(input_outputs, input)

if self.group_generations:
combined = combine_dicts(*formatted_outputs)
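The `process` hunk above is the actual bug fix: `outputs` holds one list of generations per input, and each list must be formatted against its own input row rather than the whole batch. A rough sketch of the corrected pairing, assuming hypothetical stand-ins for the LLM call and the formatting helper (the names and shapes here are illustrative assumptions, not the real API):

```python
from typing import Any, Dict, List


def fake_generate(formatted_inputs: List[Any], num_generations: int) -> List[List[str]]:
    # One inner list of generations per input, matching the comment in the diff.
    return [[f"gen_{i}" for i in range(num_generations)] for _ in formatted_inputs]


def format_outputs_for(input: Dict[str, Any], generations: List[str]) -> List[Dict[str, Any]]:
    return [{"output": g, "info_from_input": input["additional_info"]} for g in generations]


inputs = [
    {"instruction": "test_0", "additional_info": "additional_info_0"},
    {"instruction": "test_1", "additional_info": "additional_info_1"},
]
outputs = fake_generate(inputs, num_generations=3)

task_outputs = []
for input, input_outputs in zip(inputs, outputs):
    # The fix: pass the single matching `input`, not the whole `inputs` batch.
    task_outputs.extend(format_outputs_for(input, input_outputs))

print(task_outputs)
```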
131 changes: 121 additions & 10 deletions tests/unit/steps/tasks/test_base.py
@@ -31,16 +31,22 @@
class DummyTask(Task):
@property
def inputs(self) -> List[str]:
return ["instruction"]
return ["instruction", "additional_info"]

def format_input(self, input: Dict[str, Any]) -> "ChatType":
return [
{"role": "system", "content": ""},
{"role": "user", "content": input["instruction"]},
]

def format_output(self, output: Union[str, None], input: Dict[str, Any]) -> dict:
return {"output": output}
@property
def outputs(self) -> List[str]:
return ["output", "info_from_input"]

def format_output(
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
return {"output": output, "info_from_input": input["additional_info"]} # type: ignore


class DummyRuntimeLLM(DummyLLM):
@@ -85,37 +91,139 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
Task(name="task", llm=DummyLLM()) # type: ignore

@pytest.mark.parametrize(
"group_generations, expected",
"input, group_generations, expected",
[
(
[
{"instruction": "test_0", "additional_info": "additional_info_0"},
{"instruction": "test_1", "additional_info": "additional_info_1"},
{"instruction": "test_2", "additional_info": "additional_info_2"},
],
False,
[
{
"instruction": "test",
"instruction": "test_0",
"additional_info": "additional_info_0",
"output": "output",
"info_from_input": "additional_info_0",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test_0",
"additional_info": "additional_info_0",
"output": "output",
"info_from_input": "additional_info_0",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test_0",
"additional_info": "additional_info_0",
"output": "output",
"info_from_input": "additional_info_0",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test_1",
"additional_info": "additional_info_1",
"output": "output",
"info_from_input": "additional_info_1",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test_1",
"additional_info": "additional_info_1",
"output": "output",
"info_from_input": "additional_info_1",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test_1",
"additional_info": "additional_info_1",
"output": "output",
"info_from_input": "additional_info_1",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test",
"instruction": "test_2",
"additional_info": "additional_info_2",
"output": "output",
"info_from_input": "additional_info_2",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test",
"instruction": "test_2",
"additional_info": "additional_info_2",
"output": "output",
"info_from_input": "additional_info_2",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test_2",
"additional_info": "additional_info_2",
"output": "output",
"info_from_input": "additional_info_2",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
],
),
(
[
{"instruction": "test_0", "additional_info": "additional_info_0"},
{"instruction": "test_1", "additional_info": "additional_info_1"},
{"instruction": "test_2", "additional_info": "additional_info_2"},
],
True,
[
{
"instruction": "test",
"instruction": "test_0",
"additional_info": "additional_info_0",
"output": ["output", "output", "output"],
"info_from_input": [
"additional_info_0",
"additional_info_0",
"additional_info_0",
],
"model_name": "test",
"distilabel_metadata": [
{"raw_output_task": "output"},
{"raw_output_task": "output"},
{"raw_output_task": "output"},
],
},
{
"instruction": "test_1",
"additional_info": "additional_info_1",
"output": ["output", "output", "output"],
"info_from_input": [
"additional_info_1",
"additional_info_1",
"additional_info_1",
],
"model_name": "test",
"distilabel_metadata": [
{"raw_output_task": "output"},
{"raw_output_task": "output"},
{"raw_output_task": "output"},
],
},
{
"instruction": "test_2",
"additional_info": "additional_info_2",
"output": ["output", "output", "output"],
"info_from_input": [
"additional_info_2",
"additional_info_2",
"additional_info_2",
],
"model_name": "test",
"distilabel_metadata": [
{"raw_output_task": "output"},
@@ -128,7 +236,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
],
)
def test_process(
self, group_generations: bool, expected: List[Dict[str, Any]]
self,
input: List[Dict[str, str]],
group_generations: bool,
expected: List[Dict[str, Any]],
) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
llm = DummyLLM()
@@ -139,7 +250,7 @@ def test_process(
group_generations=group_generations,
num_generations=3,
)
result = next(task.process([{"instruction": "test"}]))
result = next(task.process(input))
assert result == expected

def test_process_with_runtime_parameters(self) -> None:
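The `group_generations=True` expectations in the test above show the per-generation dictionaries for each input being merged into a single row with list-valued fields. A rough sketch of that combining step (the real helper is `combine_dicts` in distilabel; this version is only an assumed approximation for illustration):

```python
from collections import defaultdict
from typing import Any, Dict, List


def combine_dicts(*dicts: Dict[str, Any]) -> Dict[str, List[Any]]:
    # Gather each key's value across all generation dictionaries into a list.
    combined: Dict[str, List[Any]] = defaultdict(list)
    for d in dicts:
        for key, value in d.items():
            combined[key].append(value)
    return dict(combined)


generations = [
    {"output": "output", "info_from_input": "additional_info_0"},
    {"output": "output", "info_from_input": "additional_info_0"},
    {"output": "output", "info_from_input": "additional_info_0"},
]
print(combine_dicts(*generations))
# -> {'output': ['output', 'output', 'output'],
#     'info_from_input': ['additional_info_0', 'additional_info_0', 'additional_info_0']}
```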
