diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml
index b5b053fc1549..35942ef531c7 100644
--- a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml
@@ -38,6 +38,55 @@ web_port: 9889 # the port number of the web server
 chat: False # use the chat interface
 chatbot_config:
   value: False # whether to inject the value attributes
+  attributes:
+    - name: Quality
+      min: 0
+      max: 4
+      key: quality
+      type: int
+      default: 4
+    - name: Toxicity
+      min: 0
+      max: 4
+      key: toxicity
+      type: int
+      default: 0
+    - name: Humor
+      min: 0
+      max: 4
+      key: humor
+      type: int
+      default: 0
+    - name: Creativity
+      min: 0
+      max: 4
+      key: creativity
+      type: int
+      default: 0
+    - name: Violence
+      min: 0
+      max: 4
+      key: violence
+      type: int
+      default: 0
+    - name: Helpfulness
+      min: 0
+      max: 4
+      key: helpfulness
+      type: int
+      default: 4
+    - name: Not_Appropriate
+      min: 0
+      max: 4
+      key: not_appropriate
+      type: int
+      default: 0
+    - name: Language
+      choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh']
+      key: lang
+      type: list
+      default: en
+
   user: User
   assistant: Assistant
   system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py
index 76e68d24bae8..2d34bf78ed01 100644
--- a/examples/nlp/language_modeling/megatron_gpt_eval.py
+++ b/examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -314,7 +314,12 @@ def main(cfg) -> None:
             'assistant': cfg.chatbot_config.assistant,
             'system': cfg.chatbot_config.system,
         }
-        web_ui = partial(get_chatbot_demo, defaults=defaults, value=cfg.chatbot_config.value)
+        web_ui = partial(
+            get_chatbot_demo,
+            defaults=defaults,
+            value=cfg.chatbot_config.value,
+            attributes=cfg.chatbot_config.attributes,
+        )
     else:
         web_ui = get_demo
     loop = asyncio.new_event_loop()
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
index f15138c99264..3639e2386288 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
@@ -102,6 +102,7 @@ model:
     truncation_field: "context" # Options: ['context', 'answer']
     index_mapping_dir: null # Path to a directory to write index mapping files.
     prompt_template: null # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
+    hf_dataset: False # Whether to load the JSON file with the HuggingFace dataset; otherwise, load the JSONL file with the JSONLMemMapDataset.
   validation_ds:
     file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
@@ -126,6 +127,7 @@ model:
     truncation_field: "context" # Options: ['context', 'answer']
     index_mapping_dir: null # Path to a directory to write index mapping files.
     prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
+    hf_dataset: False # Whether to load the JSON file with the HuggingFace dataset; otherwise, load the JSONL file with the JSONLMemMapDataset.
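For reference, the steering string the chat UI builds from these attributes is a comma-separated list of `key:value` pairs, matching the label encoding produced by `encode_labels` in `scripts/nlp_language_modeling/sft/preprocessing.py` further below. A minimal sketch, assuming the UI simply joins the configured keys with the current widget values:

```python
# Hypothetical sketch of the steering string assembled from the YAML defaults above.
attributes = [
    ('quality', 4), ('toxicity', 0), ('humor', 0), ('creativity', 0),
    ('violence', 0), ('helpfulness', 4), ('not_appropriate', 0), ('lang', 'en'),
]
value_str = ','.join(f'{key}:{val}' for key, val in attributes)
print(value_str)  # quality:4,toxicity:0,...,not_appropriate:0,lang:en
```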
     metric:
       name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
@@ -155,6 +157,7 @@ model:
     truncation_field: "context" # Options: ['context', 'answer']
     index_mapping_dir: null # Path to a directory to write index mapping files.
     prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
+    hf_dataset: False # Whether to load the JSON file with the HuggingFace dataset; otherwise, load the JSONL file with the JSONLMemMapDataset.
     metric:
       name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
index d6c2257ebabb..733ff0f829cd 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
@@ -90,12 +90,18 @@ def _mask_targets(
         # target[cur_idx + 1:cur_idx + tokenized_len] skip the turn token
         if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]):
             logging.warning("a sentence mismatches the corresponding piece in the conversation")
-        if i == 0:
+        if i == 0 and (gtype == 'VALUE_TO_TEXT' or gtype is None):
             # mask the first turn completely to provide at least one turn as context
             target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
-        elif speaker == mask_role:
+        elif speaker == mask_role and i == 1 and gtype == 'TEXT_TO_VALUE':
             # leave the first human tag unmasked
             target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
+        elif speaker == mask_role and (i > 1):
+            # mask the turn's content but leave the turn-start tag unmasked
+            target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
+        elif speaker == mask_role and (i <= 1):
+            # mask the whole turn, including the turn-start tag
+            target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
         else:
             # mask up to the name end, need to remove one as skip name has an extra artifact empty token
             target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX
@@ -109,6 +115,8 @@ def cannonical_form_formater(cannoical_form):
 def response_value_formater(label):
     if isinstance(label, str):
         return '<extra_id_2>' + label + '\n'
+    elif label is None:
+        return ''
     else:
-        raise ValueError(f'Unknown label type {type(label)}, only str type is supported')
+        raise ValueError(f'Unknown label type {type(label)}, only str or None are supported')
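To make the new branch logic concrete, here is a turn-level sketch (not the NeMo implementation) of which turns still contribute to the loss in a five-turn User/Assistant conversation. Within a kept masked-role turn only the turn-start tag survives; the sketch ignores that token-level detail:

```python
# Illustrative reduction of the _mask_targets branches above to turn granularity.
def unmasked_turns(num_turns, mask_role, gtype):
    speakers = ['User' if i % 2 == 0 else 'Assistant' for i in range(num_turns)]
    kept = []
    for i, speaker in enumerate(speakers):
        if i == 0 and (gtype == 'VALUE_TO_TEXT' or gtype is None):
            continue  # first turn is always pure context
        if speaker == mask_role:
            continue  # masked (apart from the turn-start tag)
        kept.append((i, speaker))
    return kept

print(unmasked_turns(5, 'User', 'VALUE_TO_TEXT'))       # [(1, 'Assistant'), (3, 'Assistant')]
print(unmasked_turns(5, 'Assistant', 'TEXT_TO_VALUE'))  # [(0, 'User'), (2, 'User'), (4, 'User')]
```

This matches the expectations encoded in tests/collections/nlp/test_chat_sft_dataset.py below.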
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index da3d03199c2e..cb62b3ac9f7a 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -16,11 +16,13 @@
 import numpy as np
 import torch
+from datasets import load_dataset
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping
 from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset
 from nemo.core.classes import Dataset
+from nemo.utils import logging
 
 __all__ = ['GPTSFTDataset']
@@ -49,6 +51,7 @@ def __init__(
         virtual_tokens: int = 0,
         tokens_to_generate: int = 0,
         memmap_workers: Optional[int] = None,
+        hf_dataset: bool = False,
     ):
         """
         file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines, each formatted as follows:
         {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
@@ -70,6 +73,7 @@ def __init__(
         pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch.
         index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset.
         prompt_template: Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output}
+        hf_dataset: Whether to load the JSON file with the HuggingFace dataset; otherwise, load the JSONL file with the JSONLMemMapDataset.
         """
         self.tokenizer = tokenizer
         self.file_path = file_path
@@ -96,13 +100,18 @@ def __init__(
             self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape')
         assert self.truncation_field in ["answer", "context"]
 
-        self.indexed_dataset = JSONLMemMapDataset(
-            dataset_paths=[file_path],
-            tokenizer=None,
-            header_lines=0,
-            index_mapping_dir=index_mapping_dir,
-            workers=memmap_workers,
-        )
+        if hf_dataset:
+            self.indexed_dataset = load_dataset(
+                'json', data_files=file_path, cache_dir=index_mapping_dir, num_proc=memmap_workers, split='train'
+            )
+        else:
+            self.indexed_dataset = JSONLMemMapDataset(
+                dataset_paths=[file_path],
+                tokenizer=None,
+                header_lines=0,
+                index_mapping_dir=index_mapping_dir,
+                workers=memmap_workers,
+            )
 
         # Will be None after this call if `max_num_samples` is None
         self._build_samples_mapping()
@@ -141,7 +150,11 @@ def __getitem__(self, idx):
             idx = idx.item()
 
         assert idx < len(self.indexed_dataset)
-        example = self.indexed_dataset[idx]
+        try:
+            example = self.indexed_dataset[idx]
+        except Exception as e:
+            logging.error(f"Error while loading example {idx} from dataset {self.file_path}")
+            raise e
         return self._process_example(example)
 
     def _process_example(self, example):
diff --git a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
index 05d10b42e115..706fdf1d2393 100644
--- a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -371,7 +371,13 @@ def __init__(
     def _build_data_from_text(self, text):
         """Return a dictionary of data based on a single JSON line."""
-        return json.loads(text)
+        try:
+            record = json.loads(text)
+        except Exception as e:
+            logging.error(f"Exception: {e}")
+            logging.error(f"datapoint: {text}")
+            raise e
+        return record
 
 def _index_file_exists(idx_fn):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 95108e90f087..fdec8a31c02d 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -269,6 +269,9 @@ def _build_dataset(self, data_cfg, is_train=True):
             memmap_workers=data_cfg.get(
                 'memmap_workers', None
             ),  # used to set num. of workers to create the memmap index files
+            hf_dataset=data_cfg.get(
+                'hf_dataset', False
+            ),  # whether to load the JSON file with the HuggingFace dataset; otherwise, load the JSONL file with the JSONLMemMapDataset
         )
         datasets.append(dataset)
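Both loading paths yield dict-like records indexed by integer, so the rest of GPTSFTDataset is unchanged. A minimal sketch of the HuggingFace path, with a placeholder file name:

```python
# Sketch only: 'train.jsonl' is a placeholder path.
from datasets import load_dataset

ds = load_dataset('json', data_files='train.jsonl', split='train')
print(ds[0])  # dict parsed from the first JSON line, as JSONLMemMapDataset would return
```

The 'json' builder parses JSON-lines files directly; cache_dir and num_proc in the diff reuse the existing index_mapping_dir and memmap_workers settings.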
diff --git a/nemo/collections/nlp/modules/common/megatron_web_server.py b/nemo/collections/nlp/modules/common/megatron_web_server.py
index 7c04ef201927..91fb12af102c 100644
--- a/nemo/collections/nlp/modules/common/megatron_web_server.py
+++ b/nemo/collections/nlp/modules/common/megatron_web_server.py
@@ -190,7 +190,7 @@ def clear_fun():
 def get_chatbot_demo(
-    share, username, password, server_port=5555, web_port=9889, loop=None, value=False, defaults=None
+    share, username, password, server_port=5555, web_port=9889, loop=None, value=False, defaults=None, attributes=None,
 ):
     check_gradio_import()
     from nemo.collections.nlp.modules.common.chatbot_component import Chatbot
@@ -222,28 +222,20 @@ def get_chatbot_demo(
                 )
                 with gr.Accordion("Value Parameters", open=True, visible=value):
-                    keys = ['quality', 'toxicity', 'humor', 'creativity', 'violence', 'helpfulness', 'not_appropriate']
-                    quality_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=9, label='Quality', interactive=True, visible=True
-                    )
-                    toxicity_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=0, label='Toxicity', interactive=True, visible=True
-                    )
-                    humor_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=0, label='Humor', interactive=True, visible=True
-                    )
-                    creativity_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=0, label='Creativity', interactive=True, visible=True
-                    )
-                    violence_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=0, label='Violence', interactive=True, visible=True
-                    )
-                    helpfulness_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=9, label='Helpfulness', interactive=True, visible=True
-                    )
-                    not_appropriate_value = gr.Slider(
-                        minimum=0, maximum=9, step=1, value=0, label='Not Appropriate', interactive=True, visible=True
-                    )
+                    keys = [k.key for k in attributes]
+                    widgets = []
+                    for item in attributes:
+                        if item.type == 'int':
+                            slider = gr.Slider(
+                                minimum=item.min, maximum=item.max, step=1, value=item.default, label=item.name
+                            )
+                            widgets.append(slider)
+                        elif item.type == 'list':
+                            dropdown = gr.Dropdown(item.choices, label=item.name, value=item.default)
+                            widgets.append(dropdown)
                     used_value = gr.CheckboxGroup(keys, value=keys)
 
                     def change_visibility(x):
@@ -256,17 +248,7 @@ def change_visibility(x):
                         return values
 
                     used_value.change(
-                        change_visibility,
-                        inputs=[used_value],
-                        outputs=[
-                            quality_value,
-                            toxicity_value,
-                            humor_value,
-                            creativity_value,
-                            violence_value,
-                            helpfulness_value,
-                            not_appropriate_value,
-                        ],
+                        change_visibility, inputs=[used_value], outputs=widgets,
                     )
 
                     def set_sampling(x):
@@ -328,25 +310,11 @@ def bot(
                 assistant_name,
                 session_state,
                 prompts_presets,
-                quality_value,
-                toxicity_value,
-                humor_value,
-                creativity_value,
-                violence_value,
-                helpfulness_value,
-                not_appropriate_value,
                 used_value,
+                *values,
             ):
-                values_array = [
-                    quality_value,
-                    toxicity_value,
-                    humor_value,
-                    creativity_value,
-                    violence_value,
-                    helpfulness_value,
-                    not_appropriate_value,
-                ]
+                values_array = values
                 if value:
                     value_str = get_value_str(values_array, used_value)
                 else:
                     value_str = None
@@ -400,14 +368,8 @@ def bot(
                     assistant_name,
                     session_state,
                     prompt_presets,
-                    quality_value,
-                    toxicity_value,
-                    humor_value,
-                    creativity_value,
-                    violence_value,
-                    helpfulness_value,
-                    not_appropriate_value,
                     used_value,
+                    *widgets,
                 ],
                 [chatbot],
             )
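Each int attribute now becomes a gr.Slider and each list attribute a gr.Dropdown, so the widget set follows the config instead of seven hard-coded sliders. A hypothetical launch wiring, mirroring megatron_gpt_eval.py, with a hand-written stand-in for cfg.chatbot_config.attributes:

```python
from functools import partial

from omegaconf import OmegaConf

from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo

# Stand-in for cfg.chatbot_config.attributes; values are illustrative.
attributes = OmegaConf.create(
    [
        {'name': 'Quality', 'min': 0, 'max': 4, 'key': 'quality', 'type': 'int', 'default': 4},
        {'name': 'Language', 'choices': ['en', 'de'], 'key': 'lang', 'type': 'list', 'default': 'en'},
    ]
)
web_ui = partial(
    get_chatbot_demo,
    defaults={'user': 'User', 'assistant': 'Assistant', 'system': 'A chat.\n\n'},
    value=True,
    attributes=attributes,
)
```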
diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt
index 68d8b8985748..edc643cbbd3a 100644
--- a/requirements/requirements_nlp.txt
+++ b/requirements/requirements_nlp.txt
@@ -1,4 +1,5 @@
 boto3
+datasets
 einops
 faiss-cpu
 fasttext
diff --git a/scripts/nlp_language_modeling/sft/attribute_annotate.py b/scripts/nlp_language_modeling/sft/attribute_annotate.py
new file mode 100644
index 000000000000..4e447a63b72b
--- /dev/null
+++ b/scripts/nlp_language_modeling/sft/attribute_annotate.py
@@ -0,0 +1,366 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to annotate datasets using a trained attribute prediction model.
+First, launch the NeMo Megatron inference server.
+Example:
+```bash
+    python examples/nlp/language_modeling/megatron_gpt_eval.py \
+        gpt_model_file=/models/TRAINED_ATTR_PREDICTION_MODEL.nemo \
+        pipeline_model_parallel_split_rank=0 \
+        server=True \
+        tensor_model_parallel_size=TP_SIZE \
+        pipeline_model_parallel_size=PP_SIZE \
+        trainer.precision=bf16 \
+        trainer.devices=TP_SIZE*PP_SIZE \
+        trainer.num_nodes=1 \
+        web_server=False \
+        port=1424
+```
+
+Then, run this script to annotate the dataset.
+Example usage:
+
+python scripts/nlp_language_modeling/sft/attribute_annotate.py --batch_size=1 --host=localhost --input_file_name=input.jsonl --output_file_name=output.jsonl --port_num=1424
+"""
+
+import json
+import os
+
+import fire
+import tqdm
+from langchain.prompts.few_shot import PromptTemplate
+
+from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import text_generation
+
+langs = [
+    'ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu',
+    'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh',
+]
+
+SFT_PREFIX = """<extra_id_0>System
+{system_message}"""
+
+ONE_TURN_WITH_VAL = """<extra_id_1>{user_name}
+{user_message}
+<extra_id_2>{label}
+"""
+
+ONE_TURN_WITHOUT_VAL = """<extra_id_1>{user_name}
+{user_message}
+"""
+
+SYSTEM = PromptTemplate(input_variables=["system_message"], template=SFT_PREFIX)
+EXAMPLE_PROMPT_WITH_VAL = PromptTemplate(
+    input_variables=["user_name", "user_message", "label"], template=ONE_TURN_WITH_VAL
+)
+EXAMPLE_PROMPT_WITHOUT_VAL = PromptTemplate(
+    input_variables=["user_name", "user_message"], template=ONE_TURN_WITHOUT_VAL
+)
+
+selected_keys = [
+    'quality', 'toxicity', 'humor', 'creativity', 'violence', 'helpfulness', 'not_appropriate',
+    'hate_speech', 'sexual_content', 'fails_task', 'political_content', 'moral_judgement', 'lang',
+]
+
+
+def calculate_key(obj):
+    return ":".join([item['value'] for item in obj['conversations']])
+
+
+def load_data(path):
+    with open(path, 'r', encoding='utf-8') as fin:
+        for line in fin:
+            yield json.loads(line)
+
+
+def get_prompt(data_obj, turn, current_label="", label_id=0):
+    if len(data_obj['conversations']) < turn + 1:
+        return None
+
+    examples = []
+    for i in range(0, turn):
+        d = data_obj['conversations'][i]
+        if 'label' in d:
+            examples.append(
+                EXAMPLE_PROMPT_WITH_VAL.format(
+                    **{'user_name': d['from'], 'user_message': d['value'], 'label': d['label']}
+                )
+            )
+        else:
+            examples.append(EXAMPLE_PROMPT_WITHOUT_VAL.format(**{'user_name': d['from'], 'user_message': d['value']}))
+
+    example_text = "".join(examples)
+    d = data_obj['conversations'][turn]
+    predict_message = EXAMPLE_PROMPT_WITHOUT_VAL.format(**{'user_name': d['from'], 'user_message': d['value']})
+
+    if label_id != 0:
+        current_label = current_label + ',' + selected_keys[label_id] + ':'
+    else:
+        current_label = '<extra_id_2>' + selected_keys[label_id] + ':'
+    return SYSTEM.format(**{'system_message': data_obj['system']}) + example_text + predict_message + current_label
+
+
+def create_gen_function(host='localhost', port=5555):
+    def request(prompts, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, end_strings):
+        data = {
+            "sentences": prompts,
+            "tokens_to_generate": int(token_to_gen),
+            "temperature": temp,
+            "add_BOS": add_BOS,
+            "top_k": top_k,
+            "top_p": top_p,
+            "greedy": greedy,
+            "all_probs": False,
+            "repetition_penalty": repetition,
+            "min_tokens_to_generate": int(min_tokens),
+            "end_strings": end_strings,
+        }
+        response = text_generation(data, ip=host, port=port)
+        sentences = response['sentences']
+        return sentences
+
+    return request
+
+
+class Worker(object):
+    def __init__(self, host='localhost', port=5555, progress_bar=None, output_file=None, process_lang=False):
+        self.req = create_gen_function(host=host, port=port)
+        self.fout = open(output_file, "a", encoding='utf-8')
+        self.progress_bar = progress_bar
+        self.process_lang = process_lang
+
+    def process_result(self, batch):
+        while True:
+            try:
+                items = [i['item'] for i in batch]
+                turns = [i['turn'] for i in batch]
+                prompts = [i['prompt'] for i in batch]
+
+                for label_id in range(1, len(selected_keys)):
+                    results = self.req(
+                        prompts,
+                        greedy=True,
+                        add_BOS=False,
+                        token_to_gen=1,
+                        min_tokens=1,
+                        temp=0.1,
+                        top_p=1.0,
+                        top_k=1,
+                        repetition=1.0,
+                        end_strings=["<extra_id_1>", "<|endoftext|>"],
+                    )
+                    # get the current value string from the result
+                    current_values = []
+                    nums = []
+                    for result in results:
+                        # note: result[-1] may be '\n'
+                        current_val = result.split('quality')[-1]
+                        current_val = 'quality' + current_val
+                        # remove whatever follows the newline
+                        current_val = current_val.split('\n')[0].strip()
+                        # keep everything up to and including selected_keys[label_id - 1]
+                        splits = current_val.split(',')
+                        filtered = []
+                        for item in splits:
+                            filtered.append(item)
+                            if item.split(':')[0] == selected_keys[label_id - 1]:
+                                nums.append(item.split(':')[1])
+                                break
+                        current_val = '<extra_id_2>' + ','.join(filtered)
+                        current_values.append(current_val)
+
+                    filtered_items = []
+                    filtered_turns = []
+                    filtered_prompts = []
+                    filtered_current_values = []
+
+                    for result, item, turn, num, current_value in zip(results, items, turns, nums, current_values):
+                        try:
+                            value = int(num)
+                        except Exception as e:
+                            print(f'error {e} when converting {num} to int')
+                            continue
+                        filtered_current_values.append(current_value)
+                        filtered_items.append(item)
+                        filtered_turns.append(turn)
+                        if label_id < len(selected_keys):
+                            prompt = get_prompt(item, turn, current_label=current_value, label_id=label_id)
+                            filtered_prompts.append(prompt)
+                    items = filtered_items
+                    turns = filtered_turns
+                    prompts = filtered_prompts
+                    current_values = filtered_current_values
+
+                if self.process_lang:
+                    results = self.req(
+                        prompts,
+                        greedy=True,
+                        add_BOS=False,
+                        token_to_gen=1,
+                        min_tokens=1,
+                        temp=0.1,
+                        top_p=1.0,
+                        top_k=1,
+                        repetition=1.0,
+                        end_strings=["<extra_id_1>", "<|endoftext|>"],
+                    )
+                    # get the language prediction from the last line of the result
+                    current_values = []
+                    for result in results:
+                        # note: result[-1] may be '\n'
+                        if result.endswith('\n'):
+                            result = result[:-1] + '@'
+                        current_values.append(result.split('\n')[-1])
+
+                    nums = []
+                    for result in results:
+                        current_val = result.split('quality')[-1]
+                        current_val = 'quality' + current_val
+                        # remove whatever follows the newline
+                        current_val = current_val.split('\n')[0].strip()
+                        # keep everything up to and including selected_keys[label_id]
+                        splits = current_val.split(',')
+                        filtered = []
+                        for item in splits:
+                            filtered.append(item)
+                            if item.split(':')[0] == selected_keys[label_id]:
+                                nums.append(item.split(':')[1])
+                                break
+                        current_val = '<extra_id_2>' + ','.join(filtered)
+                        current_values.append(current_val)
+
+                    filtered_items = []
+                    filtered_turns = []
+                    filtered_prompts = []
+                    filtered_current_values = []
+
+                    for result, item, turn, num, current_value in zip(results, items, turns, nums, current_values):
+                        if num not in langs:
+                            print(f'error {num} not in langs')
+                            continue
+                        filtered_current_values.append(current_value)
+                        filtered_items.append(item)
+                        filtered_turns.append(turn)
+                    items = filtered_items
+                    turns = filtered_turns
+                    current_values = filtered_current_values
+
+                batch = []
+                for item, turn, current_value in zip(items, turns, current_values):
+                    response_text = current_value[12:]  # strip the leading '<extra_id_2>'
+                    if 'label' in item['conversations'][turn]:
+                        item['conversations'][turn]['gt_label'] = item['conversations'][turn]['label']
+                    item['conversations'][turn]['label'] = response_text
+                    prompt = get_prompt(item, turn + 1, current_label='', label_id=0)
+                    if prompt is not None:
+                        batch.append({'prompt': prompt, 'item': item, 'turn': turn + 1})
+                    else:
+                        self.progress_bar.update(1)
+                        self.fout.write(json.dumps(item, ensure_ascii=False) + "\n")
+                        self.fout.flush()
+                        if self.progress_bar.n >= self.progress_bar.total:
+                            break
+                if len(batch) == 0:
+                    break
+            except Exception as e:
+                print(f'error {e} when processing {batch}')
+                # ignore the error and continue
+                self.progress_bar.update(1)
+                if self.progress_bar.n >= self.progress_bar.total:
+                    break
+
+
+def main(
+    batch_size=1,
+    host='localhost',
+    input_file_name='input.jsonl',
+    output_file_name='output.jsonl',
+    port_num=1424,
+    process_lang=True,
+):
+    input_data = load_data(f'{input_file_name}')
+    output_path = f'{output_file_name}'
+    existing_requests = set()
+    if os.path.exists(output_path):
+        with open(output_path, 'r', encoding='utf-8') as fin:
+            for line in fin:
+                line = json.loads(line)
+                existing_requests.add(calculate_key(line))
+        print(f"Loaded {len(existing_requests)} existing requests")
+
+    filter_data = [d for d in input_data if calculate_key(d) not in existing_requests]
+
+    progress_bar = tqdm.tqdm(total=len(filter_data))
+
+    worker = Worker(
+        host=host, port=port_num, progress_bar=progress_bar, output_file=output_path, process_lang=process_lang
+    )
+    for batch_idx in range(0, len(filter_data), batch_size):
+        batch = [line for line in filter_data[batch_idx : batch_idx + batch_size]]
+        turns = [
+            0 if 'mask' not in d['conversations'][0]['from'] or d['conversations'][0]['from'] == d['mask'] else 1
+            for d in batch
+        ]
+        task = [{'prompt': get_prompt(d, turn, "", 0), 'item': d, 'turn': turn} for d, turn in zip(batch, turns)]
+        worker.process_result(task)
+    worker.fout.close()
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
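For orientation, this is the shape of the prompt that get_prompt assembles when asking for the first attribute of turn 1. The conversation text is invented; only the token layout follows SFT_PREFIX, ONE_TURN_WITHOUT_VAL, and the current_label seeding above:

```python
# Illustrative only: a hand-built annotation prompt for label_id=0 ('quality').
prompt = (
    '<extra_id_0>System\n'
    "A chat between a curious user and an artificial intelligence assistant. ...\n\n"  # data_obj['system']
    '<extra_id_1>User\n'
    'How do magnets work?\n'
    '<extra_id_1>Assistant\n'
    'Magnets attract through their magnetic fields.\n'
    '<extra_id_2>quality:'
)
```

The server greedily decodes one token per request, and the script appends each predicted key:value pair to current_label before asking for the next attribute.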
diff --git a/scripts/nlp_language_modeling/sft/data_clean.py b/scripts/nlp_language_modeling/sft/data_clean.py
new file mode 100644
index 000000000000..8c67aa2e3bcd
--- /dev/null
+++ b/scripts/nlp_language_modeling/sft/data_clean.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to clean the data for the SFT chat dataset. It removes records whose tokens are all masked due to truncation by the sequence length.
+Example usage:
+
+MPT-7B:
+    python data_clean.py --dataset_file /dataset/INPUT.jsonl --output_file /dataset/OUTPUT.jsonl --library huggingface --model_name EleutherAI/gpt-neox-20b --seq_len 4096
+NeMo GPT:
+    python data_clean.py --dataset_file /dataset/INPUT.jsonl --output_file /dataset/OUTPUT.jsonl --library sentencepiece --model_file sentencepiece.model --seq_len 4096
+"""
+
+
+import argparse
+import json
+
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+
+
+def data_clean(
+    dataset_file,
+    output_file,
+    seq_len=4096,
+    library='huggingface',
+    model_name='EleutherAI/gpt-neox-20b',
+    tokenizer_model=None,
+):
+    tokenizer = get_nmt_tokenizer(
+        library=library, model_name=model_name, tokenizer_model=tokenizer_model, use_fast=True
+    )
+    if library == 'huggingface':
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+    d = GPTSFTChatDataset(dataset_file, tokenizer, seq_len, 1)
+    total_records = len(d)
+    removed_ids = set()
+    for i in range(total_records):
+        if i % 1000 == 0:
+            print(i)
+        try:
+            if d[i]['mask'][: seq_len + 1].sum().item() == 0:
+                removed_ids.add(i)
+                print(f'removed {i}')
+                continue
+        except Exception:
+            removed_ids.add(i)
+            print(f'Exception removed {i}')
+    with open(dataset_file, 'r', encoding='utf-8') as f:
+        with open(output_file, 'w', encoding='utf-8') as o:
+            for i, line in enumerate(f):
+                if i in removed_ids:
+                    continue
+                obj = json.loads(line)
+                o.write(json.dumps(obj, ensure_ascii=False) + '\n')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset_file", type=str, required=True)
+    parser.add_argument(
+        "--model_file", type=str, required=False, default=None, help="Path to the sentencepiece model file."
+    )
+    parser.add_argument(
+        "--library",
+        type=str,
+        required=False,
+        default='huggingface',
+        help="tokenizer library, huggingface or sentencepiece",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        required=False,
+        default='EleutherAI/gpt-neox-20b',
+        help="huggingface tokenizer model name",
+    )
+    parser.add_argument("--output_file", type=str, required=True)
+    parser.add_argument("--seq_len", type=int, required=False, default=4096)
+    args = parser.parse_args()
+    data_clean(
+        dataset_file=args.dataset_file,
+        output_file=args.output_file,
+        seq_len=args.seq_len,
+        library=args.library,
+        model_name=args.model_name,
+        tokenizer_model=args.model_file,
+    )
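A record is dropped when no unmasked target token falls within the first seq_len + 1 positions, since such an example contributes nothing to the loss. A minimal sketch of the criterion with a toy mask:

```python
import torch

seq_len = 8
# Toy loss mask: the only unmasked token sits beyond the truncation point.
mask = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
keep = mask[: seq_len + 1].sum().item() != 0
print(keep)  # False -> data_clean.py would drop this record
```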
diff --git a/scripts/nlp_language_modeling/sft/preprocessing.py b/scripts/nlp_language_modeling/sft/preprocessing.py
new file mode 100644
index 000000000000..7a08e055543d
--- /dev/null
+++ b/scripts/nlp_language_modeling/sft/preprocessing.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Process OASST conversation trees into chat SFT datasets.
+
+Example to create the dataset used for training the attribute prediction model:
+    python preprocessing.py --input_file=dataset/2023-04-12_oasst_all.trees.jsonl --output_file_prefix=oasst_output --mask_role=User --type=TEXT_TO_VALUE --split_ratio=0.95 --seed=10
+
+Example to create the dataset used for the attribute-conditioned SFT model:
+    python preprocessing.py --input_file=dataset/2023-04-12_oasst_all.trees.jsonl --output_file_prefix=oasst_output --mask_role=User --type=VALUE_TO_TEXT --split_ratio=0.95 --seed=10
+"""
+
+import json
+import random
+
+import fire
+
+# All the keys: ['spam', 'lang_mismatch', 'pii', 'not_appropriate', 'hate_speech', 'sexual_content', 'quality', 'toxicity', 'humor', 'creativity', 'violence', 'fails_task', 'helpfulness', 'political_content', 'moral_judgement']
+selected_keys = [
+    'quality', 'toxicity', 'humor', 'creativity', 'violence', 'helpfulness', 'not_appropriate',
+    'hate_speech', 'sexual_content', 'fails_task', 'political_content', 'moral_judgement',
+]
+label_values = {}
+likert_scale = 5
+system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
+
+
+def encode_labels(labels):
+    items = []
+    for key in selected_keys:
+        if key in labels:
+            value = labels[key]['value']
+            items.append(f'{key}:{round(value * (likert_scale - 1))}')
+    return ','.join(items)
+
+
+def parse_conversations(tree_obj):
+    """Recursive function that returns all the sub-conversations in a list, starting from node tree_obj.
+
+    Args:
+        tree_obj (obj): current conversation node
+
+    Returns:
+        a list of sub-conversation threads including the current conversation node
+    """
+    if 'prompt' in tree_obj:
+        prompt_obj = tree_obj['prompt']
+    elif 'text' in tree_obj and 'role' in tree_obj:
+        prompt_obj = tree_obj
+    else:
+        return [[]]
+    if prompt_obj['role'] == 'prompter':
+        role = 'User'
+    elif prompt_obj['role'] == 'assistant':
+        role = 'Assistant'
+    else:
+        raise ValueError(f'unknown role {prompt_obj["role"]}')
+    turn = {'value': prompt_obj['text'], 'from': role}
+    if 'labels' in prompt_obj:
+        turn['human_labels'] = prompt_obj['labels']
+        for key in turn['human_labels']:
+            value_set = label_values.get(key, set())
+            value_set.add(turn['human_labels'][key]['value'])
+            label_values[key] = value_set
+        turn['label'] = encode_labels(prompt_obj['labels'])
+    if 'lang' in prompt_obj:
+        turn['lang'] = prompt_obj['lang'].split('-')[0]
+        if turn['label'] == '':
+            turn['label'] = f'lang:{turn["lang"]}'
+        else:
+            turn['label'] = turn['label'] + f',lang:{turn["lang"]}'
+        value_set = label_values.get('lang', set())
+        value_set.add(turn['lang'])
+        label_values['lang'] = value_set
+    all_conversations = []
+    multiple_sub_threads = []
+    for next_obj in prompt_obj['replies']:
+        multiple_threads = parse_conversations(next_obj)
+        multiple_sub_threads.extend(multiple_threads)
+    if len(multiple_sub_threads) != 0:
+        for sub_thread in multiple_sub_threads:
+            all_conversations.append([turn] + sub_thread)
+    else:
+        all_conversations.append([turn])
+    return all_conversations
+
+
+def get_data_records(objs, mask_role, type):
+    output = []
+    for obj in objs:
+        multi_conversations = parse_conversations(obj)
+        for conversations in multi_conversations:
+            if len(conversations) <= 1:
+                # remove single-turn conversations
+                continue
+            conversation_obj = {}
+            conversation_obj['conversations'] = []
+            conversation_obj['tree_id'] = obj['message_tree_id']
+            conversation_obj['conversations'] = conversations
+            conversation_obj['system'] = system_prompt
+            conversation_obj['mask'] = mask_role
+            conversation_obj['type'] = type
+            output.append(conversation_obj)
+    return output
+
+
+def main(
+    input_file='2023-04-12_oasst_all.trees.jsonl',
+    output_file_prefix='oasst_output',
+    mask_role='User',
+    type='TEXT_TO_VALUE',
+    split_ratio=0.95,
+    seed=10,
+):
+    all_objs = []
+    with open(input_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            obj = json.loads(line)
+            all_objs.append(obj)
+    random.seed(seed)
+    random.shuffle(all_objs)
+    train_num = int(len(all_objs) * split_ratio)
+    train_objs = all_objs[:train_num]
+    val_objs = all_objs[train_num:]
+    train_records = get_data_records(train_objs, mask_role, type)
+    val_records = get_data_records(val_objs, mask_role, type)
+
+    with open(f'{output_file_prefix}_train.jsonl', 'w', encoding='utf-8') as f:
+        for record in train_records:
+            f.write(json.dumps(record, ensure_ascii=False) + '\n')
+
+    with open(f'{output_file_prefix}_val.jsonl', 'w', encoding='utf-8') as f:
+        for record in val_records:
+            f.write(json.dumps(record, ensure_ascii=False) + '\n')
+
+    for label in label_values:
+        values = sorted(list(label_values[label]))
+        print(f'{label} values: {values}')
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
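Each emitted JSONL record has the following shape (values illustrative; optional keys such as human_labels are omitted):

```python
record = {
    'tree_id': '<message_tree_id from the OASST tree>',
    'system': 'A chat between a curious user and an artificial intelligence assistant. ...\n\n',
    'mask': 'User',            # role whose turns are masked during SFT
    'type': 'VALUE_TO_TEXT',   # or TEXT_TO_VALUE for the attribute prediction model
    'conversations': [
        {'from': 'User', 'value': 'How do magnets work?', 'label': 'lang:en'},
        {'from': 'Assistant', 'value': 'Magnets attract ...', 'label': 'quality:4,toxicity:0,lang:en'},
    ],
}
```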
diff --git a/tests/collections/nlp/test_chat_sft_dataset.py b/tests/collections/nlp/test_chat_sft_dataset.py
new file mode 100644
index 000000000000..36d00e3108d7
--- /dev/null
+++ b/tests/collections/nlp/test_chat_sft_dataset.py
@@ -0,0 +1,317 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import os
+import random
+
+import pytest
+
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+
+TOKENIZER_FILE_43B = '/home/TestData/nlp/megatron_sft/tokenizer.model'
+MERGE_FILE = '/home/TestData/nlp/megatron_sft/merges.txt'
+VOCAB_FILE = '/home/TestData/nlp/megatron_sft/vocab.json'
+
+
+def ids_to_text(tokenizer, ids):
+    tokens = tokenizer.ids_to_tokens(ids)
+    text = tokenizer.tokens_to_text(tokens)
+    return text
+
+
+def get_random_sentence():
+    nouns = ("puppy", "car", "rabbit", "girl", "monkey")
+    verbs = ("runs", "hits", "jumps", "drives", "barfs")
+    adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
+    num1 = random.randrange(0, 5)
+    num2 = random.randrange(0, 5)
+    num3 = random.randrange(0, 5)
+    return nouns[num1] + ' ' + verbs[num2] + ' ' + adv[num3]
+
+
+def get_random_label():
+    keys = ["quality", "toxicity", "humor", "creativity", "violence", "helpfulness", "not_appropriate"]
+    values = [random.randrange(0, 5) for i in range(len(keys))]
+    return ",".join([k + ":" + str(v) for k, v in zip(keys, values)])
+
+
+def create_data_points(mask_user, turn_num, records, temp_file, t2v, label=True):
+    data_points = []
+    with open(temp_file, 'w', encoding='utf-8') as f:
+        for r in range(records):
+            record = {}
+            record['system'] = 'a chat\n\n'
+            record['type'] = 'TEXT_TO_VALUE' if t2v else 'VALUE_TO_TEXT'
+            record['mask'] = 'User' if mask_user else 'Assistant'
+            turns = []
+            record['conversations'] = turns
+            for i in range(turn_num):
+                turn = {}
+                turn['from'] = 'User' if i % 2 == 0 else 'Assistant'
+                turn['value'] = get_random_sentence()
+                if label:
+                    turn['label'] = get_random_label()
+                turns.append(turn)
+            f.write(json.dumps(record, ensure_ascii=False) + '\n')
+            data_points.append(record)
+    return data_points
+
+
+class TestGPTSFTChatDataset:
+    @classmethod
+    def setup_class(cls):
+        pass
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_user(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(True, turn_num, records, temp_file, t2v=False)
+            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(1, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_assistant(self):
+        random.seed(3)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(False, turn_num, records, temp_file, t2v=False)
+            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(2, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_user_t2v(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(True, turn_num, records, temp_file, t2v=True)
+            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(1, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_assistant_t2v(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(False, turn_num, records, temp_file, t2v=True)
+            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(0, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_user(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(True, turn_num, records, temp_file, t2v=False)
+            tokenizer = get_nmt_tokenizer(
+                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+            )
+            tokenizer.add_special_tokens(
+                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            )
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = ids_to_text(tokenizer, input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(1, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_assistant(self):
+        random.seed(3)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(False, turn_num, records, temp_file, t2v=False)
+            tokenizer = get_nmt_tokenizer(
+                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+            )
+            tokenizer.add_special_tokens(
+                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            )
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = ids_to_text(tokenizer, input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(2, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_user_t2v(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(True, turn_num, records, temp_file, t2v=True)
+            tokenizer = get_nmt_tokenizer(
+                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+            )
+            tokenizer.add_special_tokens(
+                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            )
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = ids_to_text(tokenizer, input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(1, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_assistant_t2v(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(False, turn_num, records, temp_file, t2v=True)
+            tokenizer = get_nmt_tokenizer(
+                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+            )
+            tokenizer.add_special_tokens(
+                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            )
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = ids_to_text(tokenizer, input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(0, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_user_nolabel(self):
+        random.seed(5)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(True, turn_num, records, temp_file, t2v=False, label=False)
+            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(1, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_assistant_nolabel(self):
+        random.seed(3)
+        temp_file = '/tmp/test_file.jsonl'
+        turn_num = 5
+        records = 5
+        try:
+            data_points = create_data_points(False, turn_num, records, temp_file, t2v=False, label=False)
+            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            for i in range(len(d)):
+                result = d[i]
+                input_ids = result['input_ids']
+                mask = result['mask']
+                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                expected_text = ''
+                for j in range(2, turn_num, 2):
+                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                assert text == expected_text
+        finally:
+            os.remove(temp_file)
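These tests rely on tokenizer assets under /home/TestData, which exist in the NeMo CI containers. Outside CI, the dataset can be exercised directly in the same way the tests do; the paths below are placeholders:

```python
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# 'tokenizer.model' and 'train.jsonl' are placeholder paths.
tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model='tokenizer.model')
ds = GPTSFTChatDataset('train.jsonl', tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
sample = ds[0]
# Decode only the unmasked (loss-bearing) tokens, as the assertions above do.
print(tokenizer.ids_to_text(sample['input_ids'][sample['mask']].tolist()))
```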