From d91a8b5e767af37980ba386ccbb14420a2f53fcd Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 20 Apr 2023 22:13:14 +0200 Subject: [PATCH 1/6] update poem instruction path --- data/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/datasets/__init__.py b/data/datasets/__init__.py index 284f874e7e..dbdc0d98b6 100644 --- a/data/datasets/__init__.py +++ b/data/datasets/__init__.py @@ -22,7 +22,7 @@ "oa_leet10k": "ehartford/oa_leet10k", "LogicInference_OA": "KK04/LogicInference_OA", "oa_dolly_15k": "OllieStanley/oa_dolly_15k", - "poetry_instruction": "checkai/poetry-instruction", + "poetry_instruction": "checkai/instruction-poems", } SAFETY_DATASETS = { From bb651e353f8f3bae71f4640cee4b667e660a21e0 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Fri, 21 Apr 2023 18:12:02 +0200 Subject: [PATCH 2/6] update entities --- .../custom_datasets/entities.py | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 model/model_training/custom_datasets/entities.py diff --git a/model/model_training/custom_datasets/entities.py b/model/model_training/custom_datasets/entities.py new file mode 100644 index 0000000000..68c28a9f72 --- /dev/null +++ b/model/model_training/custom_datasets/entities.py @@ -0,0 +1,195 @@ +from enum import Enum, unique + + +class Mode(Enum): + sft = "sft" + rm = "rm" + rl = "rl" + + +@unique +class Language(str, Enum): + AB = "ab" # Abkhazian + AA = "aa" # Afar + AF = "af" # Afrikaans + AK = "ak" # Akan + SQ = "sq" # Albanian + AM = "am" # Amharic + AR = "ar" # Arabic + AN = "an" # Aragonese + HY = "hy" # Armenian + AS = "as" # Assamese + AV = "av" # Avaric + AE = "ae" # Avestan + AY = "ay" # Aymara + AZ = "az" # Azerbaijani + BM = "bm" # Bambara + BA = "ba" # Bashkir + EU = "eu" # Basque + BE = "be" # Belarusian + BN = "bn" # Bengali + BH = "bh" # Bihari languages + BI = "bi" # Bislama + BS = "bs" # Bosnian + BR = "br" # Breton + BG = "bg" # Bulgarian + MY = "my" # Burmese + CA = "ca" # Catalan, Valencian + CH = "ch" # Chamorro + CE = "ce" # Chechen + NY = "ny" # Chichewa, Chewa, Nyanja + ZH = "zh" # Chinese + CV = "cv" # Chuvash + KW = "kw" # Cornish + CO = "co" # Corsican + CR = "cr" # Cree + HR = "hr" # Croatian + CS = "cs" # Czech + DA = "da" # Danish + DV = "dv" # Divehi, Dhivehi, Maldivian + NL = "nl" # Dutch, Flemish + DZ = "dz" # Dzongkha + EN = "en" # English + EO = "eo" # Esperanto + ET = "et" # Estonian + EE = "ee" # Ewe + FO = "fo" # Faroese + FJ = "fj" # Fijian + FI = "fi" # Finnish + FR = "fr" # French + FF = "ff" # Fulah + GL = "gl" # Galician + KA = "ka" # Georgian + DE = "de" # German + EL = "el" # Greek, Modern (1453-) + GN = "gn" # Guarani + GU = "gu" # Gujarati + HT = "ht" # Haitian, Haitian Creole + HA = "ha" # Hausa + HE = "he" # Hebrew + HZ = "hz" # Herero + HI = "hi" # Hindi + HO = "ho" # Hiri Motu + HU = "hu" # Hungarian + IA = "ia" # Interlingua (International Auxiliary Language Association) + ID = "id" # Indonesian + IE = "ie" # Interlingue, Occidental + GA = "ga" # Irish + IG = "ig" # Igbo + IK = "ik" # Inupiaq + IO = "io" # Ido + IS = "is" # Icelandic + IT = "it" # Italian + IU = "iu" # Inuktitut + JA = "ja" # Japanese + JV = "jv" # Javanese + KL = "kl" # Kalaallisut, Greenlandic + KN = "kn" # Kannada + KR = "kr" # Kanuri + KS = "ks" # Kashmiri + KK = "kk" # Kazakh + KM = "km" # Central Khmer + KI = "ki" # Kikuyu, Gikuyu + RW = "rw" # Kinyarwanda + KY = "ky" # Kirghiz, Kyrgyz + KV = "kv" # Komi + KG = "kg" # Kongo + KO = "ko" # Korean + KU = "ku" # Kurdish + KJ = "kj" # 
Kuanyama, Kwanyama + LA = "la" # Latin + LB = "lb" # Luxembourgish, Letzeburgesch + LG = "lg" # Ganda + LI = "li" # Limburgan, Limburger, Limburgish + LN = "ln" # Lingala + LO = "lo" # Lao + LT = "lt" # Lithuanian + LU = "lu" # Luba-Katanga + LV = "lv" # Latvian + GV = "gv" # Manx + MK = "mk" # Macedonian + MG = "mg" # Malagasy + MS = "ms" # Malay + ML = "ml" # Malayalam + MT = "mt" # Maltese + MI = "mi" # Maori + MR = "mr" # Marathi + MH = "mh" # Marshallese + MN = "mn" # Mongolian + NA = "na" # Nauru + NV = "nv" # Navajo, Navaho + ND = "nd" # North Ndebele + NE = "ne" # Nepali + NG = "ng" # Ndonga + NB = "nb" # Norwegian Bokmål + NN = "nn" # Norwegian Nynorsk + NO = "no" # Norwegian + II = "ii" # Sichuan Yi, Nuosu + NR = "nr" # South Ndebele + OC = "oc" # Occitan + OJ = "oj" # Ojibwa + CU = "cu" # Church Slavic, Old Slavonic, Church Slavonic, Old Bulgarian, Old Church Slavonic + OM = "om" # Oromo + OR = "or" # Oriya + OS = "os" # Ossetian, Ossetic + PA = "pa" # Panjabi, Punjabi + PI = "pi" # Pali + FA = "fa" # Persian + PL = "pl" # Polish + PS = "ps" # Pashto, Pushto + PT = "pt" # Portuguese + QU = "qu" # Quechua + RM = "rm" # Romansh + RN = "rn" # Rundi + RO = "ro" # Romanian, Moldavian, Moldovan + RU = "ru" # Russian + SA = "sa" # Sanskrit + SC = "sc" # Sardinian + SD = "sd" # Sindhi + SE = "se" # Northern Sami + SM = "sm" # Samoan + SG = "sg" # Sango + SR = "sr" # Serbian + GD = "gd" # Gaelic, Scottish Gaelic + SN = "sn" # Shona + SI = "si" # Sinhala, Sinhalese + SK = "sk" # Slovak + SL = "sl" # Slovenian + SO = "so" # Somali + ST = "st" # Southern Sotho + ES = "es" # Spanish, Castilian + SU = "su" # Sundanese + SW = "sw" # Swahili + SS = "ss" # Swati + SV = "sv" # Swedish + TA = "ta" # Tamil + TE = "te" # Telugu + TG = "tg" # Tajik + TH = "th" # Thai + TI = "ti" # Tigrinya + BO = "bo" # Tibetan + TK = "tk" # Turkmen + TL = "tl" # Tagalog + TN = "tn" # Tswana + TO = "to" # Tonga (Tonga Islands) + TR = "tr" # Turkish + TS = "ts" # Tsonga + TT = "tt" # Tatar + TW = "tw" # Twi + TY = "ty" # Tahitian + UG = "ug" # Uighur, Uyghur + UK = "uk" # Ukrainian + UR = "ur" # Urdu + UZ = "uz" # Uzbek + VE = "ve" # Venda + VI = "vi" # Vietnamese + VO = "vo" # Volapük + WA = "wa" # Walloon + CY = "cy" # Welsh + WO = "wo" # Wolof + FY = "fy" # Western Frisian + XH = "xh" # Xhosa + YI = "yi" # Yiddish + YO = "yo" # Yoruba + ZA = "za" # Zhuang, Chuang + ZU = "zu" # Zulu From ef1ec019495ab1c99b64aebc00f4c0d4888877c0 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Fri, 21 Apr 2023 20:37:02 +0200 Subject: [PATCH 3/6] update custom datasets init --- .../custom_datasets/__init__.py | 7 +++-- .../custom_datasets/qa_datasets.py | 31 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index 4cbf314fe1..1183a5eb0e 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -12,6 +12,7 @@ SODA, AlpacaGpt4, DatabricksDolly15k, + InstructionPoems, JokeExplaination, QADataset, SODADialogue, @@ -162,9 +163,11 @@ def get_one_dataset( elif dataset_name == "hellaswag": train, eval = load_hellaswag() elif dataset_name == "dolly15k": - dataset = DatabricksDolly15k(cache_dir=data_path) + dataset = DatabricksDolly15k(cache_dir=data_path, mode=mode, **kwargs) elif dataset_name == "alpaca_gpt4": - dataset = AlpacaGpt4(cache_dir=data_path, **kwargs) + dataset = AlpacaGpt4(cache_dir=data_path, mode=mode, **kwargs) + elif dataset_name 
== "instruction_poems": + dataset = InstructionPoems(cache_dir=data_path, mode=mode, **kwargs) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index 31dc4ca858..7fa97dc2d2 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -20,6 +20,8 @@ # @agoryuno contributed this re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") +# check if the whole string is just a combination of (multiple) whitespaces and newlines +re_whitespace_newline_match = re.compile(r"^[\s\n]*$") LINKING_CHARS = ["\n", "\n\n", " "] @@ -448,6 +450,9 @@ def process_split( dataset: Subset, reverse_augmentation: bool = False, keep_unreversed: bool = True ) -> list[tuple[str, str]]: data = [] + import pdb + + pdb.set_trace() for row in dataset: question = row["instruction"] if len(row["input"]) > 0: @@ -616,3 +621,29 @@ def __getitem__(self, index: int) -> list[str] | tuple[str]: return dialogue elif self.mode == "rl": return tuple(dialogue[:-1]) + + +class InstructionPoems(Dataset): + def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: int = 2048) -> None: + super().__init__() + self.rows = [] + if mode not in ("sft", "rl"): + raise NotImplementedError(f"Currently only the modes 'sft' and 'rl' are implemented. Received {mode}.") + self.mode = mode + data = load_dataset("checkai/instruction-poems", cache_dir=cache_dir) + for line in data["train"]: + if (conv := self._process_instruction(line, input_max_length)) is not None: + self.rows.append(conv) + + def _process_instruction(self, row: dict[str, str], input_max_length: int) -> list[str] | None: + return [row["INSTRUCTION"], row["RESPONSE"]] + + def __len__(self) -> int: + return len(self.rows) + + def __getitem__(self, index: int) -> list[str] | tuple[str]: + dialogue: list[str] = self.rows[index] + if self.mode == "sft": + return dialogue + elif self.mode == "rl": + return tuple(dialogue[:-1]) From 7212d96fda1fd9a0ef659c0e4d8d0861e5e70e24 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Fri, 21 Apr 2023 20:48:09 +0200 Subject: [PATCH 4/6] move poem instruction dataset --- .../custom_datasets/__init__.py | 3 --- .../custom_datasets/instruction.py | 1 + .../custom_datasets/qa_datasets.py | 26 ------------------- 3 files changed, 1 insertion(+), 29 deletions(-) diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index 1183a5eb0e..13029d750f 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -12,7 +12,6 @@ SODA, AlpacaGpt4, DatabricksDolly15k, - InstructionPoems, JokeExplaination, QADataset, SODADialogue, @@ -166,8 +165,6 @@ def get_one_dataset( dataset = DatabricksDolly15k(cache_dir=data_path, mode=mode, **kwargs) elif dataset_name == "alpaca_gpt4": dataset = AlpacaGpt4(cache_dir=data_path, mode=mode, **kwargs) - elif dataset_name == "instruction_poems": - dataset = InstructionPoems(cache_dir=data_path, mode=mode, **kwargs) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/model_training/custom_datasets/instruction.py b/model/model_training/custom_datasets/instruction.py index 3ca022d62f..05aa492836 100644 --- a/model/model_training/custom_datasets/instruction.py +++ b/model/model_training/custom_datasets/instruction.py @@ -18,6 +18,7 @@ "zhihu-kol": "wangrui6/zhihu-kol", "minimath": 
"kentsui/minimath", "oa_wiki_qa_bart_10000row": "michaelthwan/oa_wiki_qa_bart_10000row", + "poem_instructions": "checkai/instruction-poems", } diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index 7fa97dc2d2..85095c2a3b 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -621,29 +621,3 @@ def __getitem__(self, index: int) -> list[str] | tuple[str]: return dialogue elif self.mode == "rl": return tuple(dialogue[:-1]) - - -class InstructionPoems(Dataset): - def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: int = 2048) -> None: - super().__init__() - self.rows = [] - if mode not in ("sft", "rl"): - raise NotImplementedError(f"Currently only the modes 'sft' and 'rl' are implemented. Received {mode}.") - self.mode = mode - data = load_dataset("checkai/instruction-poems", cache_dir=cache_dir) - for line in data["train"]: - if (conv := self._process_instruction(line, input_max_length)) is not None: - self.rows.append(conv) - - def _process_instruction(self, row: dict[str, str], input_max_length: int) -> list[str] | None: - return [row["INSTRUCTION"], row["RESPONSE"]] - - def __len__(self) -> int: - return len(self.rows) - - def __getitem__(self, index: int) -> list[str] | tuple[str]: - dialogue: list[str] = self.rows[index] - if self.mode == "sft": - return dialogue - elif self.mode == "rl": - return tuple(dialogue[:-1]) From 612d0d37da4e45b9dfd83a0e897d3dc229bfe009 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Fri, 21 Apr 2023 20:51:23 +0200 Subject: [PATCH 5/6] remove entities --- .../custom_datasets/entities.py | 195 ------------------ 1 file changed, 195 deletions(-) delete mode 100644 model/model_training/custom_datasets/entities.py diff --git a/model/model_training/custom_datasets/entities.py b/model/model_training/custom_datasets/entities.py deleted file mode 100644 index 68c28a9f72..0000000000 --- a/model/model_training/custom_datasets/entities.py +++ /dev/null @@ -1,195 +0,0 @@ -from enum import Enum, unique - - -class Mode(Enum): - sft = "sft" - rm = "rm" - rl = "rl" - - -@unique -class Language(str, Enum): - AB = "ab" # Abkhazian - AA = "aa" # Afar - AF = "af" # Afrikaans - AK = "ak" # Akan - SQ = "sq" # Albanian - AM = "am" # Amharic - AR = "ar" # Arabic - AN = "an" # Aragonese - HY = "hy" # Armenian - AS = "as" # Assamese - AV = "av" # Avaric - AE = "ae" # Avestan - AY = "ay" # Aymara - AZ = "az" # Azerbaijani - BM = "bm" # Bambara - BA = "ba" # Bashkir - EU = "eu" # Basque - BE = "be" # Belarusian - BN = "bn" # Bengali - BH = "bh" # Bihari languages - BI = "bi" # Bislama - BS = "bs" # Bosnian - BR = "br" # Breton - BG = "bg" # Bulgarian - MY = "my" # Burmese - CA = "ca" # Catalan, Valencian - CH = "ch" # Chamorro - CE = "ce" # Chechen - NY = "ny" # Chichewa, Chewa, Nyanja - ZH = "zh" # Chinese - CV = "cv" # Chuvash - KW = "kw" # Cornish - CO = "co" # Corsican - CR = "cr" # Cree - HR = "hr" # Croatian - CS = "cs" # Czech - DA = "da" # Danish - DV = "dv" # Divehi, Dhivehi, Maldivian - NL = "nl" # Dutch, Flemish - DZ = "dz" # Dzongkha - EN = "en" # English - EO = "eo" # Esperanto - ET = "et" # Estonian - EE = "ee" # Ewe - FO = "fo" # Faroese - FJ = "fj" # Fijian - FI = "fi" # Finnish - FR = "fr" # French - FF = "ff" # Fulah - GL = "gl" # Galician - KA = "ka" # Georgian - DE = "de" # German - EL = "el" # Greek, Modern (1453-) - GN = "gn" # Guarani - GU = "gu" # Gujarati - HT = "ht" # Haitian, Haitian Creole - 
HA = "ha" # Hausa - HE = "he" # Hebrew - HZ = "hz" # Herero - HI = "hi" # Hindi - HO = "ho" # Hiri Motu - HU = "hu" # Hungarian - IA = "ia" # Interlingua (International Auxiliary Language Association) - ID = "id" # Indonesian - IE = "ie" # Interlingue, Occidental - GA = "ga" # Irish - IG = "ig" # Igbo - IK = "ik" # Inupiaq - IO = "io" # Ido - IS = "is" # Icelandic - IT = "it" # Italian - IU = "iu" # Inuktitut - JA = "ja" # Japanese - JV = "jv" # Javanese - KL = "kl" # Kalaallisut, Greenlandic - KN = "kn" # Kannada - KR = "kr" # Kanuri - KS = "ks" # Kashmiri - KK = "kk" # Kazakh - KM = "km" # Central Khmer - KI = "ki" # Kikuyu, Gikuyu - RW = "rw" # Kinyarwanda - KY = "ky" # Kirghiz, Kyrgyz - KV = "kv" # Komi - KG = "kg" # Kongo - KO = "ko" # Korean - KU = "ku" # Kurdish - KJ = "kj" # Kuanyama, Kwanyama - LA = "la" # Latin - LB = "lb" # Luxembourgish, Letzeburgesch - LG = "lg" # Ganda - LI = "li" # Limburgan, Limburger, Limburgish - LN = "ln" # Lingala - LO = "lo" # Lao - LT = "lt" # Lithuanian - LU = "lu" # Luba-Katanga - LV = "lv" # Latvian - GV = "gv" # Manx - MK = "mk" # Macedonian - MG = "mg" # Malagasy - MS = "ms" # Malay - ML = "ml" # Malayalam - MT = "mt" # Maltese - MI = "mi" # Maori - MR = "mr" # Marathi - MH = "mh" # Marshallese - MN = "mn" # Mongolian - NA = "na" # Nauru - NV = "nv" # Navajo, Navaho - ND = "nd" # North Ndebele - NE = "ne" # Nepali - NG = "ng" # Ndonga - NB = "nb" # Norwegian Bokmål - NN = "nn" # Norwegian Nynorsk - NO = "no" # Norwegian - II = "ii" # Sichuan Yi, Nuosu - NR = "nr" # South Ndebele - OC = "oc" # Occitan - OJ = "oj" # Ojibwa - CU = "cu" # Church Slavic, Old Slavonic, Church Slavonic, Old Bulgarian, Old Church Slavonic - OM = "om" # Oromo - OR = "or" # Oriya - OS = "os" # Ossetian, Ossetic - PA = "pa" # Panjabi, Punjabi - PI = "pi" # Pali - FA = "fa" # Persian - PL = "pl" # Polish - PS = "ps" # Pashto, Pushto - PT = "pt" # Portuguese - QU = "qu" # Quechua - RM = "rm" # Romansh - RN = "rn" # Rundi - RO = "ro" # Romanian, Moldavian, Moldovan - RU = "ru" # Russian - SA = "sa" # Sanskrit - SC = "sc" # Sardinian - SD = "sd" # Sindhi - SE = "se" # Northern Sami - SM = "sm" # Samoan - SG = "sg" # Sango - SR = "sr" # Serbian - GD = "gd" # Gaelic, Scottish Gaelic - SN = "sn" # Shona - SI = "si" # Sinhala, Sinhalese - SK = "sk" # Slovak - SL = "sl" # Slovenian - SO = "so" # Somali - ST = "st" # Southern Sotho - ES = "es" # Spanish, Castilian - SU = "su" # Sundanese - SW = "sw" # Swahili - SS = "ss" # Swati - SV = "sv" # Swedish - TA = "ta" # Tamil - TE = "te" # Telugu - TG = "tg" # Tajik - TH = "th" # Thai - TI = "ti" # Tigrinya - BO = "bo" # Tibetan - TK = "tk" # Turkmen - TL = "tl" # Tagalog - TN = "tn" # Tswana - TO = "to" # Tonga (Tonga Islands) - TR = "tr" # Turkish - TS = "ts" # Tsonga - TT = "tt" # Tatar - TW = "tw" # Twi - TY = "ty" # Tahitian - UG = "ug" # Uighur, Uyghur - UK = "uk" # Ukrainian - UR = "ur" # Urdu - UZ = "uz" # Uzbek - VE = "ve" # Venda - VI = "vi" # Vietnamese - VO = "vo" # Volapük - WA = "wa" # Walloon - CY = "cy" # Welsh - WO = "wo" # Wolof - FY = "fy" # Western Frisian - XH = "xh" # Xhosa - YI = "yi" # Yiddish - YO = "yo" # Yoruba - ZA = "za" # Zhuang, Chuang - ZU = "zu" # Zulu From 64fc8c09364f099044064eb9c485c49621000e0d Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Fri, 21 Apr 2023 20:52:14 +0200 Subject: [PATCH 6/6] update qa datasets --- model/model_training/custom_datasets/qa_datasets.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/model/model_training/custom_datasets/qa_datasets.py 
b/model/model_training/custom_datasets/qa_datasets.py index 85095c2a3b..be6864791c 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -450,9 +450,7 @@ def process_split( dataset: Subset, reverse_augmentation: bool = False, keep_unreversed: bool = True ) -> list[tuple[str, str]]: data = [] - import pdb - pdb.set_trace() for row in dataset: question = row["instruction"] if len(row["input"]) > 0:
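
Note on the net effect of this series: the "checkai/instruction-poems" dataset ends up registered as "poem_instructions" in INSTRUCTION_DATASETS (model/model_training/custom_datasets/instruction.py), while the dedicated InstructionPoems class added to qa_datasets.py in patch 3 is removed again in patches 4 and 5. The snippet below is a minimal, hypothetical sketch of what a loader for this dataset has to do, reconstructed from the removed class: map each record's INSTRUCTION and RESPONSE columns to an (instruction, response) pair, returning the full pair in "sft" mode and only the prompt in "rl" mode. The class name and structure are illustrative assumptions, not the repository's actual InstructionDataset implementation.

    # Hypothetical standalone sketch; not the repo's InstructionDataset class.
    from datasets import load_dataset
    from torch.utils.data import Dataset


    class PoemInstructionsSketch(Dataset):
        def __init__(self, cache_dir: str, mode: str = "sft") -> None:
            if mode not in ("sft", "rl"):
                raise NotImplementedError(f"Only 'sft' and 'rl' are supported, got {mode}.")
            self.mode = mode
            data = load_dataset("checkai/instruction-poems", cache_dir=cache_dir)
            # Each record stores the prompt under INSTRUCTION and the poem under RESPONSE,
            # the column names used by the InstructionPoems class added in patch 3.
            self.rows = [[row["INSTRUCTION"], row["RESPONSE"]] for row in data["train"]]

        def __len__(self) -> int:
            return len(self.rows)

        def __getitem__(self, index: int) -> list[str] | tuple[str]:
            dialogue = self.rows[index]
            if self.mode == "sft":
                return dialogue  # [instruction, response] pair for supervised fine-tuning
            return tuple(dialogue[:-1])  # "rl" mode keeps only the prompt

Under the same assumption, a training run would select the dataset through its registered key, e.g. by listing poem_instructions in the datasets section of the training config, rather than by instantiating a dedicated class.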