diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index af0d67fb93..610a5ee239 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -88,6 +88,7 @@ defaults:
   deepspeed_config: configs/zero_config.json
   peft_model: false
   peft_type: "lora"
+  superhot: false
 
 use_system_tag:
   use_system_tag: True
diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py
index e3c3eb3882..740fae8563 100644
--- a/model/model_training/custom_datasets/__init__.py
+++ b/model/model_training/custom_datasets/__init__.py
@@ -7,7 +7,7 @@
 from model_training.custom_datasets.extra_rm_datasets import load_anthropic_rlhf, load_hellaswag, load_shp
 from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
 from model_training.custom_datasets.oasst_dataset import load_oasst_export
-from model_training.custom_datasets.pretrain_datasets import RedPajama
+from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
 from model_training.custom_datasets.prompt_dialogue import Gpt4All, OrcaChat, load_oig_file
 from model_training.custom_datasets.qa_datasets import (
     SODA,
@@ -170,6 +170,8 @@ def get_one_dataset(
         dataset = AlpacaGpt4(cache_dir=data_path, mode=mode, **kwargs)
     elif dataset_name == "red_pajama":
         dataset = RedPajama(cache_dir=data_path, mode=mode, **kwargs)
+    elif dataset_name == "fanfics":
+        dataset = FanFics(cache_dir=data_path, mode=mode, **kwargs)
     elif dataset_name == "gpteacher_roleplay":
         dataset = GPTeacher_Roleplay(cache_dir=data_path, mode=mode, **kwargs)
     elif dataset_name == "orca-chat":
diff --git a/model/model_training/custom_datasets/pretrain_datasets.py b/model/model_training/custom_datasets/pretrain_datasets.py
index eb067a6b36..fead0719b5 100644
--- a/model/model_training/custom_datasets/pretrain_datasets.py
+++ b/model/model_training/custom_datasets/pretrain_datasets.py
@@ -1,7 +1,9 @@
 """
     Datasets for LM objective pre-training aimed to prevent catastrophic forgetting during fine-tuning
 """
+import random
 from pathlib import Path
+from typing import Optional
 
 from datasets import load_dataset
 from model_training.custom_datasets.formatting import DatasetEntryLm
@@ -11,17 +13,54 @@
 class RedPajama(Dataset):
     name = "red_pajama"
 
-    def __init__(self, cache_dir: str | Path, mode: str = "sft", char_max_len: str = 9216) -> None:
+    def __init__(
+        self,
+        cache_dir: str | Path,
+        mode: str = "sft",
+        char_max_len: Optional[int] = 65536,
+        random_offset: bool = False,
+    ) -> None:
         super().__init__()
         self.mode = mode
         assert mode in ("sft", "rm", "rl")
         self.char_max_len = char_max_len
-
+        self.random_offset = random_offset
         self.dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=cache_dir)["train"]
 
     def __len__(self) -> int:
         return len(self.dataset)
 
     def __getitem__(self, index) -> DatasetEntryLm:
-        dialogue = DatasetEntryLm(text=self.dataset[index]["text"][: self.char_max_len])
-        return dialogue
+        text = self.dataset[index]["text"]
+        if self.char_max_len and len(text) > self.char_max_len:
+            offset = 0 if not self.random_offset else random.randrange(len(text) - self.char_max_len)
+            text = text[offset : offset + self.char_max_len]
+        return DatasetEntryLm(text=text)
+
+
+class FanFics(Dataset):
+    name = "fanfics"
+
+    def __init__(
+        self,
+        cache_dir: str | Path,
+        mode: str = "sft",
+        char_max_len: Optional[int] = 65536,
+        random_offset: bool = False,
+    ) -> None:
+        super().__init__()
+        self.mode = mode
+        assert mode in ("sft", "rm", "rl")
+        self.char_max_len = char_max_len
+        self.random_offset = random_offset
+        self.dataset = load_dataset("atom-in-the-universe/fanfics-10k-50k", cache_dir=cache_dir)["train"]
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, index) -> DatasetEntryLm:
+        text = self.dataset[index]["TEXT"]
+        if self.char_max_len and len(text) > self.char_max_len:
+            offset = 0 if not self.random_offset else random.randrange(len(text) - self.char_max_len)
+            text = text[offset : offset + self.char_max_len]
+        return DatasetEntryLm(text=text)
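
The truncation logic that both new `__getitem__` implementations share can be exercised in isolation. Below is a minimal, self-contained sketch of that behaviour; the `clip_text` helper name is hypothetical and introduced here only for illustration, it is not part of the diff:

import random
from typing import Optional


def clip_text(text: str, char_max_len: Optional[int], random_offset: bool) -> str:
    # Mirrors the slicing in RedPajama/FanFics.__getitem__: texts longer than
    # char_max_len are cut to that length, either from the start (offset 0) or
    # from a random offset so long documents contribute different windows.
    # char_max_len=None (or 0) disables clipping entirely.
    if char_max_len and len(text) > char_max_len:
        offset = 0 if not random_offset else random.randrange(len(text) - char_max_len)
        text = text[offset : offset + char_max_len]
    return text


assert len(clip_text("x" * 100_000, 65536, random_offset=True)) == 65536
assert clip_text("short", 65536, random_offset=True) == "short"
assert len(clip_text("x" * 100_000, None, random_offset=False)) == 100_000

One observation on the design: `random.randrange(len(text) - self.char_max_len)` has an exclusive upper bound, so the sampled window can never end on the document's final character; `randrange(len(text) - self.char_max_len + 1)` would also allow the last full window.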
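
For completeness, enabling the new dataset from a training config would follow the per-dataset kwargs convention already used in config.yaml, with each entry's keys forwarded through **kwargs in get_one_dataset into the constructor. A hypothetical snippet, not part of this diff (the experiment name and values are illustrative only):

my_superhot_run:
  superhot: true
  datasets:
    - fanfics:
        char_max_len: 65536
        random_offset: true
    - red_pajama:
        random_offset: true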