Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Dec 11, 2024
1 parent b54d46d commit ca07a58
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/integration/test_distiset_card_with_uses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.pipeline import Pipeline
from distilabel.steps import (
FormatTextGenerationDPO,
FormatTextGenerationSFT,
LoadDataFromDicts,
)


def test_dataset_card() -> None:
    """Verify that the generated card advertises the SFT and DPO uses.

    A single-row dataset is loaded and chained through the SFT and DPO
    formatting steps; the pipeline is run without cache and the card
    produced by ``_get_card`` is checked for both section headers.
    """
    rows = [
        {
            "instruction": "What's 2+2?",
            "generation": "4",
            "generations": ["4", "5"],
            "ratings": [1, 5],
        },
    ]

    with Pipeline() as pipeline:
        loader = LoadDataFromDicts(data=rows)
        sft_formatter = FormatTextGenerationSFT()
        dpo_formatter = FormatTextGenerationDPO()

        loader >> sft_formatter >> dpo_formatter

    # Render the card once and reuse the text for every check below.
    card_text = str(pipeline.run(use_cache=False)._get_card("user/repo_id"))

    # Check that the card has the expected content
    expected_sections = (
        "## Uses\n\n### Supervised Fine-Tuning (SFT)",
        "### Direct Preference Optimization (DPO)",
    )
    for section in expected_sections:
        assert section in card_text


# Allow running this integration test directly as a script (without pytest).
if __name__ == "__main__":
    test_dataset_card()
45 changes: 45 additions & 0 deletions tests/unit/test_distiset.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,48 @@ def test_dataset_card(self, distiset: Distiset) -> None:
"size_categories": "n<1K",
"tags": ["synthetic", "distilabel", "rlaif"],
}

@pytest.mark.parametrize(
    "dataset_uses, expected",
    [
        (None, None),
        (
            [
                {
                    "title": "Title Use 1",
                    "template": "Test Template",
                    "variables": [],
                },
            ],
            "## Uses\n\n### Title Use 1\n\nTest Template",
        ),
        (
            [
                {
                    "title": "Title Use 1",
                    "template": "Test Template",
                    "variables": ["var1", "var2"],
                },
                {
                    "title": "Title Use 2",
                    "template": "Template with {{ dataset_name }}",
                    "variables": ["dataset_name"],
                },
            ],
            "## Uses\n\n### Title Use 1\n\nTest Template\n\n### Title Use 2\n\nTemplate with repo_name_or_path",
        ),
    ],
)
def test_get_card_with_uses(
    self,
    distiset: Distiset,
    dataset_uses: Optional[list[dict[str, Any]]],
    expected: Optional[str],
) -> None:
    """Check how ``_dataset_uses`` is reflected in the generated card.

    Args:
        distiset: The ``Distiset`` under test (injected by pytest).
        dataset_uses: Value assigned to ``distiset._dataset_uses``; either
            ``None`` or a list of dicts with ``title``, ``template`` and
            ``variables`` keys.
        expected: The text fragment that must appear in the rendered card,
            or ``None`` when no "Uses" section is expected at all.
    """
    distiset._dataset_uses = dataset_uses
    distiset_card = distiset._get_card("repo_name_or_path")

    if dataset_uses is None:
        # Without any declared uses the section header must be absent.
        assert "## Uses" not in str(distiset_card)
    else:
        # ``expected`` is a plain string fragment searched in the card text.
        assert expected in str(distiset_card)

0 comments on commit ca07a58

Please sign in to comment.