diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
old mode 100644
new mode 100755
index 34663a6e34..2ff4d28d2e
--- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
+++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
@@ -1,103 +1,35 @@
 import copy
-import json
 import numpy as np
-import os
-import pyonmttok
 import time
-from onmt.constants import CorpusTask, DefaultTokens
 from onmt.inference_engine import InferenceEnginePY
-from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 import onmt.opts as opts
 from onmt.utils.logging import init_logger
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed
-from onmt.transforms import get_transforms_cls
-
-
-def compute_file_ppl(output_filename):
-    with open(output_filename, "r") as f:
-        run_results = json.load(f)
-    nlls = []
-    lengths = []
-    for i, _res in enumerate(run_results["scored_results"]):
-        print(_res)
-        nlls.append(_res[0])
-        lengths.append(_res[1])
-    file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
-    print("wikitext-2 ppl: %.4f" % file_ppl)
 
 
 def tokenize_dataset(opt, context_length):
     print("Tokenization...")
-
-    # Prepare the dataset
+    # Clean and Concat the dataset
     x = open(opt.src, "r").readlines()
-    x = [_x.rstrip("\n") for _x in x]
-    y = DefaultTokens.SEP.join(x)
-
-    with open(opt.src + ".temp", "w") as writer:
-        writer.write(y)
-
-    # ########################## #
-    # Build the dataset iterator #
-    # ########################## #
-
-    # Build the vocab
-    vocab_path_in = "/nas-labs/LM/big_llms/llama/7B/llama.vocab"
-    voc = []
-    with open(vocab_path_in, "r", encoding="utf-8") as reader:
-        for line in reader:
-            line = line.strip("\n")
-            voc.append(line)
-    vocabs = {}
-    src_vocab = pyonmttok.build_vocab_from_tokens(voc)
-    vocabs["src"] = src_vocab
-    vocabs["tgt"] = src_vocab
-    vocabs["data_task"] = "lm"
-    vocabs["decoder_start_token"] = ""
-
-    transforms_cls = get_transforms_cls(opt._all_transform)
-
-    new_opt = opt
-    new_opt.gpu = -1
-    new_opt.parallel_mode = "data_parallel"
-    new_opt.src = opt.src + ".temp"
-
-    dataset_iter = build_dynamic_dataset_iter(
-        new_opt, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=-1
-    )
-
-    input_tokens = []
-    for batch, i in dataset_iter:
-        for i in range(batch["src"].size()[0]):
-            start_ids = batch["src"][i, :, 0].cpu().numpy().tolist()
-            input_tokens += [
-                vocabs["src"].lookup_index(id)
-                for id in start_ids
-                if id != vocabs["src"].lookup_token(DefaultTokens.PAD)
-            ]
-
-    def make_chunks(lst, n):
-        """Yield successive n-sized chunks from lst."""
-        for i in range(0, len(lst), n):
-            yield lst[i : i + n]
-
-    # #################### #
-    # Tokenize the dataset #
-    # ################### #
-    with open(opt.src + f".tokenized.context_{context_length}", "w") as writer:
-        for _chunk in make_chunks(input_tokens, context_length - 1):
-            writer.write(" ".join(_chunk) + "\n")
-            print(len(_chunk))
+    xx = [_x for _x in x if _x != " \n"]
+    from onmt.transforms.tokenize import SentencePieceTransform
+    tokenizer = SentencePieceTransform(opt)
+    tokenizer.warm_up()
+    tokens = tokenizer._tokenize(xx)
     print("Done !")
-
-    z = open(opt.src + f".tokenized.context_{context_length}", "r").readlines()
-    print(len(z[0].split(" ")))
+    return tokens
 
 
 def evaluate(opt):
-    """Score the wikitext2 testset"""
+    """Score the wikitext2 testset
+
+    The perplexity of the file is calculated with a window size of max_seq_length = 4096 tokens.
+    At each step, the window shifts by 512 tokens, and its first max_seq_length - stride
+    tokens are considered as context tokens. This means that their logits are not
+    taken into account, allowing this rolling perplexity to be calculated without overlap."""
+
     ArgumentParser.validate_translate_opts(opt)
     ArgumentParser._get_all_transform_translate(opt)
     ArgumentParser._validate_transforms_opts(opt)
@@ -105,37 +37,47 @@ def evaluate(opt):
     logger = init_logger(opt.log_file)
     set_random_seed(opt.seed, use_gpu(opt))
 
-    run_results = {}
-    dir_name = os.path.dirname(opt.models[0])
-    base_name = os.path.basename(opt.models[0])
-
-    output_filename = os.path.join(
-        dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3]
-    )
+    # Tokenize the dataset.
+    opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
+    tokens = tokenize_dataset(opt, context_length=512)
 
     # Build the translator (along with the model.
     engine_opt = copy.copy(opt)
     engine_opt._all_transform = []
     engine = InferenceEnginePY(engine_opt)
 
-    # Tokenize the dataset.
-    opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
-    tokenize_dataset(opt, context_length=512)
-
-    # Score the tokeznized dataset
-    engine.opt.src = opt.src + f".tokenized.context_{512}"
-    start_time = time.time()
-    scored_results = engine.score_file()
-    engine.terminate()
-    run_results["scored_results"] = scored_results
+    # Score the dataset.
+    stride = 512
+    max_seq_length = 4096
 
-    with open(output_filename, "w") as f:
-        json.dump(run_results, f, ensure_ascii=False, indent=2)
+    seq_len = len(tokens)
+    src = []
+    for begin_loc in range(0, seq_len, stride):
+        end_loc = min(begin_loc + max_seq_length, seq_len)
+        src.append(" ".join(tokens[begin_loc:end_loc]))
 
-    compute_file_ppl(output_filename)
+    start_time = time.time()
+    engine.translator.return_gold_log_probs = True
+    score_results = engine.score_list(src=src)
+    nlls = []
+    lengths = []
+    for _, log_probs, _ in score_results:
+        lengths.append(stride)
+        # zero out the context tokens
+        nlls += [
+            log_probs[i][0]
+            for i, _ in enumerate(log_probs)
+            if i > (max_seq_length - stride)
+        ]
+    ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
+    engine.terminate()
 
     end_time = time.time()
     logger.info("total run time %.2f" % (end_time - start_time))
+    logger.info(
+        "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f"  # noqa: E501
+        % (ppl)
+    )
 
 
 def _get_parser():
diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py
index a87b4ae76a..b088f497ef 100755
--- a/onmt/inference_engine.py
+++ b/onmt/inference_engine.py
@@ -163,6 +163,7 @@ def _translate(self, infer_iter):
 
     def _score(self, infer_iter):
         self.translator.with_scores = True
+        self.return_gold_log_probs = True
         return self.translator._score(infer_iter)
 
     def score_list_parallel(self, src):
diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh
index 99bfc81680..29293b925c 100755
--- a/onmt/tests/pull_request_chk.sh
+++ b/onmt/tests/pull_request_chk.sh
@@ -558,7 +558,7 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model_lm.pt \
             -ban_unk_token \
             -length_penalty none \
             -out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1
-diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(python -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling
+diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(${PYTHON} -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}
 rm $TMP_OUT_DIR/gen_sampling
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index b8c9f57203..85a8dc1ad9 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -133,6 +133,7 @@ def __init__(
         logger=None,
         seed=-1,
         with_score=False,
+        return_gold_log_probs=False,
     ):
         self.model = model
         self.vocabs = vocabs
@@ -205,6 +206,8 @@ def __init__(
             set_random_seed(seed, self._use_cuda)
 
         self.with_score = with_score
+        self.return_gold_log_probs = return_gold_log_probs
+
     @classmethod
     def from_opt(
         cls,
@@ -280,26 +283,17 @@ def _log(self, msg):
             print(msg)
 
     def _gold_score(
-        self,
-        batch,
-        enc_out,
-        src_len,
-        use_src_map,
-        enc_final_hs,
-        batch_size,
-        src,
+        self, batch, enc_out, src_len, use_src_map, enc_final_hs, batch_size, src
     ):
         if "tgt" in batch.keys() and not self.tgt_file_prefix:
-            gs = self._score_target(
-                batch,
-                enc_out,
-                src_len,
-                batch["src_map"] if use_src_map else None,
+            gs, glp = self._score_target(
+                batch, enc_out, src_len, batch["src_map"] if use_src_map else None
             )
             self.model.decoder.init_state(src, enc_out, enc_final_hs)
         else:
             gs = [0] * batch_size
-        return gs
+            glp = None
+        return gs, glp
 
     def _translate(
         self,
@@ -584,12 +578,25 @@ def _score(self, infer_iter):
         self.with_scores = True
         scored_bucket = {}
         for batch, bucket_idx in infer_iter:
-            batch_data = self.translate_batch(batch, attn_debug=False)
+            batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
             batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
+            if self.return_gold_log_probs:
+                batch_gold_log_probs = (
+                    batch_data["gold_log_probs"].cpu().numpy().tolist()
+                )
+            else:
+                batch_gold_log_probs = None
             batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
            batch_inds_in_bucket = batch["ind_in_bucket"]
             for i, _score in enumerate(batch_gold_scores):
-                scored_bucket[batch_inds_in_bucket[i]] = (_score, batch_tgt_lengths[i])
+                log_probs = (
+                    batch_gold_log_probs[i] if self.return_gold_log_probs else None
+                )
+                scored_bucket[batch_inds_in_bucket[i]] = (
+                    _score,
+                    log_probs,
+                    batch_tgt_lengths[i],
+                )
 
         score_results = [scored_bucket[i] for i in range(len(scored_bucket))]
         return score_results
@@ -720,6 +727,7 @@ def _score_target(self, batch, enc_out, src_len, src_map):
     def report_results(
         self,
         gold_score,
+        gold_log_probs,
         batch,
         batch_size,
         decode_strategy,
@@ -730,6 +738,7 @@ def report_results(
             "attention": None,
             "batch": batch,
             "gold_score": gold_score,
+            "gold_log_probs": gold_log_probs,
         }
 
         results["scores"] = decode_strategy.scores
@@ -900,7 +909,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
 
         self.model.decoder.init_state(src, enc_out, enc_final_hs)
 
-        gold_score = self._gold_score(
+        gold_score, gold_log_probs = self._gold_score(
             batch,
             enc_out,
             src_len,
@@ -961,6 +970,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
 
         return self.report_results(
             gold_score,
+            gold_log_probs,
             batch,
             batch_size,
             decode_strategy,
@@ -982,7 +992,7 @@ def _score_target(self, batch, enc_out, src_len, src_map):
         gold = tgt[:, 1:, :]
         gold_scores = log_probs.gather(2, gold)
         gold_scores = gold_scores.sum(dim=1).view(-1)
-        return gold_scores
+        return gold_scores, None
 
 
 class GeneratorLM(Inference):
@@ -1001,8 +1011,9 @@ def _align_forward(self, batch, predictions):
         """
         raise NotImplementedError
 
-    def translate_batch(self, batch, attn_debug):
+    def translate_batch(self, batch, attn_debug, scoring=False):
         """Translate a batch of sentences."""
+        max_length = 0 if scoring else self.max_length
         with torch.no_grad():
             if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                 decode_strategy = GreedySearchLM(
@@ -1015,7 +1026,7 @@ def translate_batch(self, batch, attn_debug):
                     batch_size=len(batch["srclen"]),
                     global_scorer=self.global_scorer,
                     min_length=self.min_length,
-                    max_length=self.max_length,
+                    max_length=max_length,
                     block_ngram_repeat=self.block_ngram_repeat,
                     exclusion_tokens=self._exclusion_idxs,
                     return_attention=attn_debug or self.replace_unk,
@@ -1039,7 +1050,7 @@ def translate_batch(self, batch, attn_debug):
                     n_best=self.n_best,
                     global_scorer=self.global_scorer,
                     min_length=self.min_length,
-                    max_length=self.max_length,
+                    max_length=max_length,
                     return_attention=attn_debug or self.replace_unk,
                     block_ngram_repeat=self.block_ngram_repeat,
                     exclusion_tokens=self._exclusion_idxs,
@@ -1095,14 +1106,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
         # (2) init decoder
         self.model.decoder.init_state(src, None, None)
 
-        gold_score = self._gold_score(
-            batch,
-            None,
-            src_len,
-            use_src_map,
-            None,
-            batch_size,
-            src,
+        gold_score, gold_log_probs = self._gold_score(
+            batch, None, src_len, use_src_map, None, batch_size, src
         )
 
         # (3) prep decode_strategy. Possibly repeat src objects.
@@ -1158,6 +1163,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
 
         return self.report_results(
             gold_score,
+            gold_log_probs,
             batch,
             batch_size,
             decode_strategy,
@@ -1177,7 +1183,10 @@ def _score_target(self, batch, enc_out, src_len, src_map):
         )
 
         log_probs[:, :, self._tgt_pad_idx] = 0
-        gold_scores = log_probs.gather(2, tgt)
-        gold_scores = gold_scores.sum(dim=1).view(-1)
+        gold_log_probs = log_probs.gather(2, tgt)
+        gold_scores = gold_log_probs.sum(dim=1).view(-1)
+
+        if self.return_gold_log_probs:
+            return gold_scores, gold_log_probs
 
-        return gold_scores
+        return gold_scores, None
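
For reference, outside the diff itself: a minimal, self-contained sketch of the rolling-perplexity aggregation that the new evaluate() performs. The score_fn callable and the toy scorer below are hypothetical stand-ins for engine.score_list() returning per-token gold log-probabilities (return_gold_log_probs=True); only the windowing and averaging logic mirrors the patch.

import numpy as np


def rolling_perplexity(tokens, score_fn, max_seq_length=4096, stride=512):
    """Slide a max_seq_length window over `tokens` in steps of `stride`,
    accumulating the negative log-likelihood of the last `stride` positions
    of each window; the first max_seq_length - stride tokens act as context
    and are ignored, so the rolling perplexity is computed without overlap."""
    nlls, lengths = [], []
    for begin_loc in range(0, len(tokens), stride):
        window = tokens[begin_loc : begin_loc + max_seq_length]
        log_probs = score_fn(window)  # one log-prob per token in the window
        # Keep only the non-context positions, as in the patched evaluate().
        nlls += [
            lp for i, lp in enumerate(log_probs) if i > (max_seq_length - stride)
        ]
        lengths.append(stride)
    return np.exp(-np.sum(nlls) / np.sum(lengths))


if __name__ == "__main__":
    # Toy scorer: every token gets log(1/100). Real usage would instead score
    # the whitespace-joined token window with the inference engine.
    def toy_score_fn(window):
        return [np.log(1.0 / 100)] * len(window)

    print("ppl: %.2f" % rolling_perplexity(list(range(20000)), toy_score_fn))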