diff --git a/README.md b/README.md
index a2543cd..213e102 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@
 [[`🏠 Project Homepage`](https://sites.google.com/view/safe-sora)]
+[[`📕 Paper`](https://arxiv.org/abs/2406.14477)]
 [[`🤗 SafeSora Datasets`](https://huggingface.co/datasets/PKU-Alignment/SafeSora)]
 [[`🤗 SafeSora Label`](https://huggingface.co/datasets/PKU-Alignment/SafeSora-Label)]
 [[`🤗 SafeSora Evaluation`](https://huggingface.co/datasets/PKU-Alignment/SafeSora-Eval)]
@@ -140,11 +141,13 @@ eval_data = PromptDataset.load("path/to/config", video_dir="path/to/video_dir")
 If you find the SafeSora dataset family useful in your research, please cite the following paper:
 
 ```bibtex
-@article{SafeSora2024,
-  title = {SafeSora: Towards Safety Alignment of Text2Video Generation via a Human Preference Dataset},
-  author = {Josef Dai and Tianle Chen and Xuyao Wang and Ziran Yang and Taiye Chen and Jiaming Ji and Yaodong Yang},
-  url = {https://github.com/calico-1226/safe-sora},
-  year = {2024}
+@misc{dai2024safesora,
+  title={SafeSora: Towards Safety Alignment of Text2Video Generation via a Human Preference Dataset},
+  author={Josef Dai and Tianle Chen and Xuyao Wang and Ziran Yang and Taiye Chen and Jiaming Ji and Yaodong Yang},
+  year={2024},
+  eprint={2406.14477},
+  archivePrefix={arXiv},
+  primaryClass={cs.CV}
 }
 ```
diff --git a/pyproject.toml b/pyproject.toml
index a434acc..d9647a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
     "transformers",
     "datasets",
     "tokenizers",
+    "av",
 ]
 
 dynamic = ["version"]
diff --git a/safe_sora/datasets/base.py b/safe_sora/datasets/base.py
index 288ec46..d929d5d 100644
--- a/safe_sora/datasets/base.py
+++ b/safe_sora/datasets/base.py
@@ -40,9 +40,7 @@
 def is_complete(data_dict: dict) -> bool:
     """Check if a dictionary is complete, i.e., all values are not None."""
-    for key, value in data_dict.items():
-        if key == 'info':
-            continue
+    for value in data_dict.values():
         if isinstance(value, dict) and not is_complete(value):
             return False
         if value is None:
             return False
@@ -153,7 +151,6 @@ class VideoSample(TypedDict):
     is_safe: NotRequired[bool]
     video_labels: NotRequired[HarmLabel]
     generated_from: NotRequired[str]
-    info: NotRequired[dict]
 
 
 def format_video_sample_from_dict(data: dict, contain_labels: bool = False) -> VideoSample:
@@ -183,7 +180,6 @@ def format_video_sample_from_dict(data: dict, contain_labels: bool = False) -> V
             is_safe=data.get('is_safe'),
             video_labels=video_labels,
             generated_from=data.get('generated_from'),
-            info=data.get('info', {}),
         )
 
     return VideoSample(
@@ -195,7 +191,6 @@
         video_path=data.get('video_path'),
         is_safe=data.get('is_safe'),
         generated_from=data.get('generated_from'),
-        info=data.get('info', {}),
     )
 
 
@@ -236,7 +231,6 @@ class VideoPairSample(TypedDict):
     helpfulness: NotRequired[Literal['video_0', 'video_1']]
    harmlessness: NotRequired[Literal['video_0', 'video_1']]
     sub_preferences: NotRequired[SubPreference]
-    info: NotRequired[dict[str, str]]
 
 
 def format_video_pair_sample_from_dict(data: dict) -> VideoPairSample:
@@ -270,7 +264,6 @@
         helpfulness=data.get('helpfulness'),
         harmlessness=data.get('harmlessness'),
         sub_preferences=sub_preferences,
-        info=data.get('info', {}),
     )