From b6a2a0184198e4df86679b8fce2041ce20c1470c Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Thu, 29 Jun 2023 17:12:57 +0100 Subject: [PATCH 01/12] Add initial logger support --- src/opustrainer/logger.py | 31 +++++++++++++++++++++++++++++++ src/opustrainer/trainer.py | 26 +++++++++++++++----------- 2 files changed, 46 insertions(+), 11 deletions(-) create mode 100644 src/opustrainer/logger.py diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py new file mode 100644 index 0000000..a7b61e9 --- /dev/null +++ b/src/opustrainer/logger.py @@ -0,0 +1,31 @@ +import logging +from sys import stderr +from functools import lru_cache + +def getLogLevel(name: str) -> int: + """Incredibly, i can't find a function that will do this conversion, other + than setLevel, but setLevel doesn't work for calling the different log level logs.""" + if name.upper() in logging.getLevelNamesMapping(): + return logging.getLevelNamesMapping()[name.upper()] + else: + logging.log(logging.WARNING, "unknown log level level used: " + name + " assuming warning...") + return logging.WARNING + +def log(msg: str, loglevel: str = "INFO") -> None: + level = getLogLevel(loglevel) + logging.log(level, msg) + + +@lru_cache(None) +def log_once(msg: str, loglevel: str = "INFO") -> None: + """A wrapper to log, to make sure that we only print things once""" + log(msg, loglevel) + + +def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO") -> None: + """Sets up the logger with the necessary settings""" + loggingformat = '[%(asctime)s] [Trainer] [%(levelname)s] %(message)s' + if outputfilename is None: + logging.basicConfig(stream=stderr, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') + else: + logging.basicConfig(filename=outputfilename, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') diff --git a/src/opustrainer/trainer.py b/src/opustrainer/trainer.py index f3ea12c..0ab53aa 100755 --- a/src/opustrainer/trainer.py +++ b/src/opustrainer/trainer.py @@ -26,6 +26,7 @@ from opustrainer.modifiers.surface import UpperCaseModifier, TitleCaseModifier from opustrainer.modifiers.placeholders import PlaceholderTagModifier from opustrainer.modifiers.typos import TypoModifier +from opustrainer import logger def ignore_sigint(): """Used as pre-exec hook for the trainer program as to ignore ctrl-c. We'll @@ -161,7 +162,7 @@ def close(self): self._fh.close() def _open(self): - print(f"[Trainer] Reading {self.dataset.name} for epoch {self.epoch}") + logger.log(f"Reading {self.dataset.name} for epoch {self.epoch}") # Open temporary file which will contain shuffled version of `cat self.files` fh = TemporaryFile(mode='w+', encoding='utf-8', dir=self.tmpdir) @@ -257,7 +258,7 @@ def _kill_async(self): self._pending = None def _open(self): - print(f"[Trainer] Reading {self.dataset.name} for epoch {self.epoch}") + logger.log(f"Reading {self.dataset.name} for epoch {self.epoch}") # First time self._pending is None, but all subsequent calls to _open # should have self._pending be set. @@ -606,7 +607,7 @@ def next_stage(self) -> Optional[Stage]: def run(self, *, batch_size:int=100) -> Iterable[List[str]]: """Yield batches, moving through the stages of training as datasets are consumed.""" while self.stage is not None: - print(f"[Trainer] Starting stage {self.stage.name}") + logger.log(f"Starting stage {self.stage.name}") while self.stage.until_epoch is None or self.epoch_tracker.epoch < self.stage.until_epoch: batch: List[str] = [] @@ -692,10 +693,10 @@ def run(self, trainer:Trainer, *args, **kwargs): self._dump(trainer) -def print_state(state:TrainerState, file:TextIO=sys.stdout) -> None: - print(f"[Trainer] At stage {state.stage}", file=file) +def print_state(state:TrainerState) -> None: + logger.log(f"At stage {state.stage}") for name, reader in state.datasets.items(): - print(f"[Trainer] Dataset {name}: overall epochs {reader.epoch: 3d}.{reader.line:010d}", file=file) + logger.log(f"Dataset {name}: overall epochs {reader.epoch: 3d}.{reader.line:010d}") def main() -> None: @@ -706,9 +707,12 @@ def main() -> None: parser.add_argument("--temporary-directory", '-T', default=None, type=str, help='Temporary dir, used for shuffling and tracking state') parser.add_argument("--do-not-resume", '-d', action="store_true", help='Do not resume from the previous training state') parser.add_argument("--no-shuffle", '-n', action="store_false", help='Do not shuffle, for debugging', dest="shuffle") + parser.add_argument("--log-level", type=str, default="INFO", help="Set log level. Available levels: DEBUG, INFO, WARNING, ERROR, CRITICAL") + parser.add_argument("--log-file", '-l', type=str, default=None, help="Target location for logging. Default is stderr.") parser.add_argument("trainer", type=str, nargs=argparse.REMAINDER, help="Trainer program that gets fed the input. If empty it is read from config.") args = parser.parse_args() + logger.setup_logger(args.log_file, args.log_level) with open(args.config, 'r', encoding='utf-8') as fh: config = yaml.safe_load(fh) @@ -726,7 +730,7 @@ def main() -> None: state_tracker = StateTracker(args.state or f'{args.config}.state', restore=not args.do_not_resume) # Make trainer listen to `kill -SIGUSR1 $PID` to print dataset progress - signal.signal(signal.SIGUSR1, lambda signum, handler: print_state(trainer.state(), sys.stderr)) + signal.signal(signal.SIGUSR1, lambda signum, handler: print_state(trainer.state())) model_trainer = subprocess.Popen( args.trainer or config['trainer'], @@ -749,7 +753,7 @@ def main() -> None: for batch in state_tracker.run(trainer): model_trainer.stdin.writelines(batch) except KeyboardInterrupt: - print("[Trainer] Ctrl-c pressed, stopping training") + logger.log("Ctrl-c pressed, stopping training") # Levels of waiting for the trainer. This is reached either because we ran out of batches # or because ctrl-c was pressed. Pressing ctrl-c more advances to next level of aggressiveness. @@ -761,8 +765,8 @@ def main() -> None: model_trainer.terminate() else: model_trainer.kill() - - print(f"[Trainer] waiting for trainer to {stage}. Press ctrl-c to be more aggressive") + + logger.log(f"waiting for trainer to {stage}. Press ctrl-c to be more aggressive") sys.exit(model_trainer.wait()) # blocking except KeyboardInterrupt: continue @@ -770,7 +774,7 @@ def main() -> None: # BrokenPipeError is thrown by writelines() or close() and indicates that the child trainer # process is no more. We can safely retrieve its return code and exit with that, it should # not block at this point. - print("[Trainer] trainer stopped reading input") + logger.log("trainer stopped reading input") sys.exit(model_trainer.wait()) From 4da471253ddabb51efe6a96df982fa42c3c13120 Mon Sep 17 00:00:00 2001 From: Nikolay Bogoychev Date: Fri, 30 Jun 2023 16:06:14 +0100 Subject: [PATCH 02/12] Always log to stderr --- src/opustrainer/logger.py | 12 +++++++----- src/opustrainer/trainer.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index a7b61e9..c566724 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -1,5 +1,6 @@ import logging from sys import stderr +from typing import List from functools import lru_cache def getLogLevel(name: str) -> int: @@ -23,9 +24,10 @@ def log_once(msg: str, loglevel: str = "INFO") -> None: def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO") -> None: - """Sets up the logger with the necessary settings""" + """Sets up the logger with the necessary settings. Outputs to both file and stderr""" loggingformat = '[%(asctime)s] [Trainer] [%(levelname)s] %(message)s' - if outputfilename is None: - logging.basicConfig(stream=stderr, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') - else: - logging.basicConfig(filename=outputfilename, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') + handlers = [logging.StreamHandler(stream=stderr)] + if outputfilename is not None: + handlers.append(logging.FileHandler(filename=outputfilename)) + logging.basicConfig(handlers=handlers, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') + diff --git a/src/opustrainer/trainer.py b/src/opustrainer/trainer.py index 0ab53aa..ccf81d6 100755 --- a/src/opustrainer/trainer.py +++ b/src/opustrainer/trainer.py @@ -708,7 +708,7 @@ def main() -> None: parser.add_argument("--do-not-resume", '-d', action="store_true", help='Do not resume from the previous training state') parser.add_argument("--no-shuffle", '-n', action="store_false", help='Do not shuffle, for debugging', dest="shuffle") parser.add_argument("--log-level", type=str, default="INFO", help="Set log level. Available levels: DEBUG, INFO, WARNING, ERROR, CRITICAL") - parser.add_argument("--log-file", '-l', type=str, default=None, help="Target location for logging. Default is stderr.") + parser.add_argument("--log-file", '-l', type=str, default=None, help="Target location for logging. Always logs to stderr and optionally to a file.") parser.add_argument("trainer", type=str, nargs=argparse.REMAINDER, help="Trainer program that gets fed the input. If empty it is read from config.") args = parser.parse_args() From 5bf35a7b0f18cfa6cf823ae2854a0e3b8bbf359c Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 02:05:27 +0100 Subject: [PATCH 03/12] Fix existing tests --- .../test-data/test_enzh_config.expected.out | 12 ---------- ...est_enzh_tags_advanced_config.expected.out | 23 ------------------- .../test_enzh_tags_stage_config.expected.out | 23 ------------------- .../test-data/test_zhen_config.expected.out | 12 ---------- .../test_zhen_config_prefix.expected.out | 12 ---------- tests/test_endtoend.py | 18 ++++++++++----- 6 files changed, 12 insertions(+), 88 deletions(-) diff --git a/contrib/test-data/test_enzh_config.expected.out b/contrib/test-data/test_enzh_config.expected.out index 788b13d..5099636 100644 --- a/contrib/test-data/test_enzh_config.expected.out +++ b/contrib/test-data/test_enzh_config.expected.out @@ -98,15 +98,3 @@ On 1 March, Australia Reported The First Death From COVID-19: A 78-year-old Pert Food Poisoning and Food Hygiene. 微生物检验与食品安全控制. On 20 September 201, The Australian Olympic Committee Announced The First Set Of Sailors Selected For Tokyo 2020, Namely Rio 2016 Silver Medalists And Deending World 470 Champions Mathew Belcher And William Ryan And World's Current Top-ranked Laser Sailor Matthew Wearn. 2019年 9月 20日, 澳大利亚奥林匹克委员会公布了第一批入选奥运阵容的帆船选手名单, 名单中包括 2016年里约奥运银牌得主 Mathew Belcher 和 William Ryan 。 2020年 2月 27日, 第二批入选奥运阵容的帆船选手名单正式公布。 2020年 3月 19日, Mara Stransky 确认获得代表澳大利亚参加女子辐射型的资格, 成为第三批入选的帆船选手。 429 __source__ TRANSPORT __target__ 运输 __done__ SQUADRON (429 BIISON SQUADRON) - FYLING THE CC-17 429 运输中队 (429 野牛), 使用 CC - 177 -[Trainer] Starting stage start -[Trainer] Reading clean for epoch 0 -[Trainer] Reading clean for epoch 1 -[Trainer] Reading clean for epoch 2 -[Trainer] Reading clean for epoch 3 -[Trainer] Reading clean for epoch 4 -[Trainer] Reading clean for epoch 5 -[Trainer] Reading clean for epoch 6 -[Trainer] Reading clean for epoch 7 -[Trainer] Reading clean for epoch 8 -[Trainer] Reading clean for epoch 9 -[Trainer] waiting for trainer to exit. Press ctrl-c to be more aggressive diff --git a/contrib/test-data/test_enzh_tags_advanced_config.expected.out b/contrib/test-data/test_enzh_tags_advanced_config.expected.out index ea0f031..d32a68c 100644 --- a/contrib/test-data/test_enzh_tags_advanced_config.expected.out +++ b/contrib/test-data/test_enzh_tags_advanced_config.expected.out @@ -198,26 +198,3 @@ FOOD POISONING AND FOOD HYGIENE. 微生物检验与食品安全控制. TOGETHER __TARGET__ 共同 __DONE__ AT HOME (ALSO KNOWN AS ONE WORLD: TOGETHER AT HOME) WAS A VIRTUAL __TARGET__ 虚拟 __DONE__ CONCERT __TARGET__ 演唱会 __DONE__ SERIES ORGANISED BY GLOBAL CITIZEN __TARGET__ 公民 __DONE__ AND CURATED BY SINGER LADY GAGA, IN SUPPORT __TARGET__ 支持 __DONE__ OF THE __TARGET__ , __DONE__ WORLD __TARGET__ 世界 __DONE__ HEALTH __TARGET__ 卫生 __DONE__ ORGANIZATION. 同一个世界: 共同在家 (英语: ONE WORLD: TOGETHER AT HOME), 是于 2020年 4月 18日举行的虚拟系列演唱会, 推广在 2019 冠状病毒病疫情期间保持社交距离等防疫理念, 由全球公民和歌手嘎嘎小姐共同组织发起, 以支持世界卫生组织。 On 20 September 2019, __target__ 2019年 __done__ the __target__ 了 __done__ Australian __target__ 澳大利亚 __done__ Olympic Committee __target__ 委员会 __done__ announced __target__ 公布 __done__ the first set of sailors selected __target__ 入选 __done__ for Tokyo 2020, namely Rio __target__ 里约 __done__ 2016 silver medalists and __target__ 和 __done__ defending world 470 champions Mathew Belcher and William Ryan and world's current top-ranked Laser sailor Matthew Wearn. 2019年 9月 20日, 澳大利亚奥林匹克委员会公布了第一批入选奥运阵容的帆船选手名单, 名单中包括 2016年里约奥运银牌得主 Mathew Belcher 和 William Ryan 。 2020年 2月 27日, 第二批入选奥运阵容的帆船选手名单正式公布。 2020年 3月 19日, Mara Stransky 确认获得代表澳大利亚参加女子辐射型的资格, 成为第三批入选的帆船选手。 TRANSFUSION-RELATED ACUTE LUNG __TARGET__ 肺 __DONE__ INJURY (TRALI) IS A SERIOUS BLOOD TRANSFUSION COMPLICATION CHARACTERIZED BY THE ACUTE __TARGET__ 急性 __DONE__ ONSET __TARGET__ 引发 __DONE__ OF NON-CARDIOGENIC PULMONARY EDEMA FOLLOWING TRANSFUSION __TARGET__ 输血併 __DONE__ OF BLOOD PRODUCTS. __TARGET__ 。 __DONE__ 输血相关急性肺损伤 (TRANSFUSION RELATED ACUTE LUNG INJURY; TRALI) 是一种会引发急性肺水肿的严重输血併发症。 -[Trainer] Starting stage start -[Trainer] Reading clean for epoch 0 -[Trainer] Reading clean for epoch 1 -[Trainer] Reading clean for epoch 2 -[Trainer] Reading clean for epoch 3 -[Trainer] Reading clean for epoch 4 -[Trainer] Reading clean for epoch 5 -[Trainer] Reading clean for epoch 6 -[Trainer] Reading clean for epoch 7 -[Trainer] Reading clean for epoch 8 -[Trainer] Reading clean for epoch 9 -[Trainer] Starting stage end -[Trainer] Reading clean for epoch 10 -[Trainer] Reading clean for epoch 11 -[Trainer] Reading clean for epoch 12 -[Trainer] Reading clean for epoch 13 -[Trainer] Reading clean for epoch 14 -[Trainer] Reading clean for epoch 15 -[Trainer] Reading clean for epoch 16 -[Trainer] Reading clean for epoch 17 -[Trainer] Reading clean for epoch 18 -[Trainer] Reading clean for epoch 19 -[Trainer] waiting for trainer to exit. Press ctrl-c to be more aggressive diff --git a/contrib/test-data/test_enzh_tags_stage_config.expected.out b/contrib/test-data/test_enzh_tags_stage_config.expected.out index 7cd6a06..39f9b57 100644 --- a/contrib/test-data/test_enzh_tags_stage_config.expected.out +++ b/contrib/test-data/test_enzh_tags_stage_config.expected.out @@ -198,26 +198,3 @@ Food Poisoning and Food Hygiene. 微生物检验与食品安全控制. Together at Home __target__ 在家 __done__ (also __target__ ( __done__ known as One World: Together at Home) was a virtual concert __target__ 演唱会 __done__ series organised __target__ 组织 __done__ by Global __target__ 全球 __done__ Citizen __target__ 公民 __done__ and curated by singer Lady Gaga, in __target__ 以 __done__ support of the __target__ , __done__ World __target__ 世界 __done__ Health Organization. 同一个世界: 共同在家 (英语: One World: Together at Home), 是于 2020年 4月 18日举行的虚拟系列演唱会, 推广在 2019 冠状病毒病疫情期间保持社交距离等防疫理念, 由全球公民和歌手嘎嘎小姐共同组织发起, 以支持世界卫生组织。 On 20 September __target__ 9月 __done__ 2019, the __target__ 了 __done__ Australian __target__ 澳大利亚 __done__ Olympic Committee __target__ 委员会 __done__ announced the first set of sailors selected __target__ 入选 __done__ for Tokyo 2020, namely Rio 2016 silver __target__ 银牌 __done__ medalists __target__ 阵容 __done__ and __target__ 和 __done__ defending world 470 champions Mathew Belcher __target__ Mara __done__ and William Ryan and world's __target__ 成为 __done__ current top-ranked Laser sailor Matthew Wearn. 2019年 9月 20日, 澳大利亚奥林匹克委员会公布了第一批入选奥运阵容的帆船选手名单, 名单中包括 2016年里约奥运银牌得主 Mathew Belcher 和 William Ryan 。 2020年 2月 27日, 第二批入选奥运阵容的帆船选手名单正式公布。 2020年 3月 19日, Mara Stransky 确认获得代表澳大利亚参加女子辐射型的资格, 成为第三批入选的帆船选手。 Transfusion-related __target__ 相关 __done__ acute __target__ 急性 __done__ lung injury (TRALI) is a serious __target__ 严重 __done__ blood transfusion complication __target__ TRALI) __done__ characterized by the acute __target__ 急性 __done__ onset __target__ 引发 __done__ of non-cardiogenic pulmonary edema following transfusion of blood products. 输血相关急性肺损伤 (Transfusion related acute lung injury; TRALI) 是一种会引发急性肺水肿的严重输血併发症。 -[Trainer] Starting stage start -[Trainer] Reading clean for epoch 0 -[Trainer] Reading clean for epoch 1 -[Trainer] Reading clean for epoch 2 -[Trainer] Reading clean for epoch 3 -[Trainer] Reading clean for epoch 4 -[Trainer] Reading clean for epoch 5 -[Trainer] Reading clean for epoch 6 -[Trainer] Reading clean for epoch 7 -[Trainer] Reading clean for epoch 8 -[Trainer] Reading clean for epoch 9 -[Trainer] Starting stage end -[Trainer] Reading clean for epoch 10 -[Trainer] Reading clean for epoch 11 -[Trainer] Reading clean for epoch 12 -[Trainer] Reading clean for epoch 13 -[Trainer] Reading clean for epoch 14 -[Trainer] Reading clean for epoch 15 -[Trainer] Reading clean for epoch 16 -[Trainer] Reading clean for epoch 17 -[Trainer] Reading clean for epoch 18 -[Trainer] Reading clean for epoch 19 -[Trainer] waiting for trainer to exit. Press ctrl-c to be more aggressive diff --git a/contrib/test-data/test_zhen_config.expected.out b/contrib/test-data/test_zhen_config.expected.out index 3c5554a..239f528 100644 --- a/contrib/test-data/test_zhen_config.expected.out +++ b/contrib/test-data/test_zhen_config.expected.out @@ -98,15 +98,3 @@ SIR S 标准成立于 1992年, 是美国胸科医师学会 / 重症监护医学 2020年 1月 20日, 华盛顿州确诊首例 COVID - 19 患者。 1月 29日, 成立白宫冠状病毒工作组。 1月 31日, 特朗普政府 __source__ 宣布 __target__ declared __done__ 进入公共卫生紧急状态。 On 30 January, the WHO declared a Public Health Emergency of International Concern and on January 31, the Trump administration declared a public health emergency, and placed travel restrictions on entry for travellers from China. 同日, 一个曾为钻石公主号邮轮乘客的 78 岁男性老人宣布死亡, 为澳大利亚首例因感染 COVID - 19 死亡的 __source__ 病例 __target__ reported __done__ 。他曾在西澳大利亚州 __source__ Sir __target__ the __done__ Charles Gairdner Hospital 治疗。 On 1 March, Australia reported the first death from COVID-19: a 78-year-old Perth man, who was one of the passengers from the Diamond Princess, and who had been evacuated and was being treated in Western Australia. 429 运输中队 (429 野牛), 使用 CC - 177 429 Transport Squadron (429 Bison Squadron) - Flying the CC-177 -[Trainer] Starting stage start -[Trainer] Reading clean for epoch 0 -[Trainer] Reading clean for epoch 1 -[Trainer] Reading clean for epoch 2 -[Trainer] Reading clean for epoch 3 -[Trainer] Reading clean for epoch 4 -[Trainer] Reading clean for epoch 5 -[Trainer] Reading clean for epoch 6 -[Trainer] Reading clean for epoch 7 -[Trainer] Reading clean for epoch 8 -[Trainer] Reading clean for epoch 9 -[Trainer] waiting for trainer to exit. Press ctrl-c to be more aggressive diff --git a/contrib/test-data/test_zhen_config_prefix.expected.out b/contrib/test-data/test_zhen_config_prefix.expected.out index e47d34f..0a8cb00 100644 --- a/contrib/test-data/test_zhen_config_prefix.expected.out +++ b/contrib/test-data/test_zhen_config_prefix.expected.out @@ -98,15 +98,3 @@ Together At Home (also Known As One World: Together At Home) Was A Virtual Conce Together at Home (also known as One World: Together at Home) was a virtual concert series organised by Global Citizen and curated by singer Lady Gaga, in support of the World Health Organization. 同一个 世界 : 共同 在家 ( 英语 : One World : Together at Home) , 是 于 2020年 4月 18日 举行 的 虚拟 系列 演唱会 , 推广 在 2019 冠状 病毒 病 疫情 期间 保持 社交 距离 等 防疫 理念 , 由 全球 公民 和 歌手 嘎嘎 小姐 共同 组织 发起 , 以 支持 世界 卫生 组织 。 0-3 2-4 3-5 4-31 6-8 7-7 7-10 8-11 9-12 11-15 11-16 12-23 13-22 14-24 15-23 16-49 17-41 18-42 19-43 20-44 21-41 22-41 23-45 24-46 25-46 25-47 26-52 27-53 29-51 30-54 31-55 32-56 32-57 __start__ 虚拟 系列 __end__ Together at Home (also known as One World: Together at Home) was a virtual concert series organised by Global Citizen and curated by singer Lady Gaga, in support of the World Health Organization. 同一个 世界 : 共同 在家 ( 英语 : One World : Together at Home) , 是 于 2020年 4月 18日 举行 的 虚拟 系列 演唱会 , 推广 在 2019 冠状 病毒 病 疫情 期间 保持 社交 距离 等 防疫 理念 , 由 全球 公民 和 歌手 嘎嘎 小姐 共同 组织 发起 , 以 支持 世界 卫生 组织 。 0-3 2-4 3-5 4-31 6-8 7-7 7-10 8-11 9-12 11-15 11-16 12-23 13-22 14-24 15-23 16-49 17-41 18-42 19-43 20-44 21-41 22-41 23-45 24-46 25-46 25-47 26-52 27-53 29-51 30-54 31-55 32-56 32-57 On 20 September 2019, the Australian Olympic Committee announced the first set of sailors selected for Tokyo 2020, namely Rio 2016 silver medalists and defending world 470 champions Mathew Belcher and William Ryan and world's current top-ranked Laser sailor Matthew Wearn. 2019年 9月 20日 , 澳大利亚 奥林匹克 委员会 公布 了 第一 批 入选 奥运 阵容 的 帆船 选手 名单 , 名单 中 包括 2016年 里约 奥运 银牌 得主 Mathew Belcher 和 William Ryan 。 2020年 2月 27日 , 第二 批 入选 奥运 阵容 的 帆船 选手 名单 正式 公布 。 2020年 3月 19日 , Mara Stransky 确认 获得 代表 澳大利亚 参加 女子 辐射型 的 资格 , 成为 第三 批 入选 的 帆船 选手 。 0-2 0-3 1-2 2-1 3-0 4-8 5-4 6-5 7-6 8-7 10-9 10-10 11-10 13-16 14-39 17-33 19-23 20-22 21-25 22-13 23-29 28-27 29-53 31-30 32-31 34-65 36-61 36-62 37-69 40-70 40-71 -[Trainer] Starting stage start -[Trainer] Reading clean for epoch 0 -[Trainer] Reading clean for epoch 1 -[Trainer] Reading clean for epoch 2 -[Trainer] Reading clean for epoch 3 -[Trainer] Reading clean for epoch 4 -[Trainer] Reading clean for epoch 5 -[Trainer] Reading clean for epoch 6 -[Trainer] Reading clean for epoch 7 -[Trainer] Reading clean for epoch 8 -[Trainer] Reading clean for epoch 9 -[Trainer] waiting for trainer to exit. Press ctrl-c to be more aggressive diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index 580a023..fee8a5a 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -6,28 +6,32 @@ class TestEndToEnd(unittest.TestCase): '''Tests the pipeline end-to-end. Aimed to to test the parser.''' def test_full_enzh(self): - output: str = subprocess.check_output([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config.yml', '-d', '--sync'], encoding="utf-8") + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout reference: str = "" with open('contrib/test-data/test_enzh_config.expected.out', 'r', encoding='utf-8') as reffile: reference: str = "".join(reffile.readlines()) self.assertEqual(output, reference) def test_full_zhen(self): - output: str = subprocess.check_output([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_config.yml', '-d', '--sync'], encoding="utf-8") + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout reference: str = "" with open('contrib/test-data/test_zhen_config.expected.out', 'r', encoding='utf-8') as reffile: reference: str = "".join(reffile.readlines()) self.assertEqual(output, reference) def test_prefix_augment(self): - output: str = subprocess.check_output([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_prefix_config.yml', '-d', '--sync'], encoding="utf-8") + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_prefix_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout reference: str = "" with open('contrib/test-data/test_zhen_config_prefix.expected.out', 'r', encoding='utf-8') as reffile: reference: str = "".join(reffile.readlines()) self.assertEqual(output, reference) def test_no_shuffle(self): - output: str = subprocess.check_output([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-n'], encoding="utf-8") + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout reference: str = "" with open('contrib/test-data/clean.enzh.10', 'r', encoding='utf-8') as reffile: reference: str = "".join(reffile.readlines()) @@ -41,14 +45,16 @@ def test_no_shuffle(self): self.assertEqual(output_arr[i], reference_arr[i]) def test_advanced_config(self): - output: str = subprocess.check_output([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_advanced_config.yml', '-d', '-n'], encoding="utf-8") + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_advanced_config.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout reference: str = "" with open('contrib/test-data/test_enzh_tags_advanced_config.expected.out', 'r', encoding='utf-8') as reffile: reference: str = "".join(reffile.readlines()) self.assertEqual(output, reference) def test_stage_config(self): - output: str = subprocess.check_output([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_stage_config.yml', '-d', '-n'], encoding="utf-8") + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_stage_config.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout reference: str = "" with open('contrib/test-data/test_enzh_tags_stage_config.expected.out', 'r', encoding='utf-8') as reffile: reference: str = "".join(reffile.readlines()) From f994a74934cf6d99ef7ef527ee38120c3810aed8 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 02:18:13 +0100 Subject: [PATCH 04/12] Make typing happy --- src/opustrainer/logger.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index c566724..136710a 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -1,6 +1,7 @@ +from io import TextIOWrapper import logging from sys import stderr -from typing import List +from typing import List, TextIO from functools import lru_cache def getLogLevel(name: str) -> int: @@ -26,7 +27,8 @@ def log_once(msg: str, loglevel: str = "INFO") -> None: def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO") -> None: """Sets up the logger with the necessary settings. Outputs to both file and stderr""" loggingformat = '[%(asctime)s] [Trainer] [%(levelname)s] %(message)s' - handlers = [logging.StreamHandler(stream=stderr)] + handlers: List[logging.StreamHandler[TextIO] | logging.StreamHandler[TextIOWrapper]] = [] + handlers.append(logging.StreamHandler(stream=stderr)) if outputfilename is not None: handlers.append(logging.FileHandler(filename=outputfilename)) logging.basicConfig(handlers=handlers, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') From e08d60598fc16d082b5d1743daef3e0fbda19509 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 02:50:51 +0100 Subject: [PATCH 05/12] Log tests --- .../test_enzh_config_plain_expected.log | 12 ++++++++ tests/test_logger.py | 28 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 contrib/test-data/test_enzh_config_plain_expected.log create mode 100644 tests/test_logger.py diff --git a/contrib/test-data/test_enzh_config_plain_expected.log b/contrib/test-data/test_enzh_config_plain_expected.log new file mode 100644 index 0000000..ea33b7b --- /dev/null +++ b/contrib/test-data/test_enzh_config_plain_expected.log @@ -0,0 +1,12 @@ +[2023-07-04 02:48:11] [Trainer] [INFO] Starting stage start +[2023-07-04 02:48:11] [Trainer] [INFO] Reading clean for epoch 0 +[2023-07-04 02:48:11] [Trainer] [INFO] Reading clean for epoch 1 +[2023-07-04 02:48:11] [Trainer] [INFO] Reading clean for epoch 2 +[2023-07-04 02:48:11] [Trainer] [INFO] Reading clean for epoch 3 +[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 4 +[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 5 +[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 6 +[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 7 +[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 8 +[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 9 +[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit. Press ctrl-c to be more aggressive diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000..2fa2ced --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,28 @@ +import sys +import unittest +import subprocess +import tempfile + +from typing import List + +class TestLogger(unittest.TestCase): + '''Tests the logger using.''' + def test_file_and_stderr(self): + with tempfile.NamedTemporaryFile(suffix='.log', prefix="opustrainer") as tmpfile: + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-l', tmpfile.name], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + log = process.stderr + # Check that the stderr output is the same the output written to the file + with open(tmpfile.name, 'r', encoding='utf-8') as logfile: + file_output: str = "".join(logfile.readlines()) + self.assertEqual(log, file_output) + # Check if the log is the same as the reference log. This is more complicated as we have a time field + with open('contrib/test-data/test_enzh_config_plain_expected.log', 'r', encoding='utf-8') as reffile: + reference: List[str] = reffile.readlines() + loglist: List[str] = log.split('\n') + # Strip the time field and test + for i in range(len(reference)): + ref = reference[i].split('[Trainer]')[1] + ref = ref.strip('\n') + logout = loglist[i].split('[Trainer]')[1] + self.assertEqual(logout, ref) + From c80581eeef946d0fbbb638d8da9bd9a0d824b4b0 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 02:54:26 +0100 Subject: [PATCH 06/12] Extra comment in test --- tests/test_logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_logger.py b/tests/test_logger.py index 2fa2ced..db30ddf 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -18,6 +18,7 @@ def test_file_and_stderr(self): # Check if the log is the same as the reference log. This is more complicated as we have a time field with open('contrib/test-data/test_enzh_config_plain_expected.log', 'r', encoding='utf-8') as reffile: reference: List[str] = reffile.readlines() + # Loglist has one extra `\n` compared to reference list, due to stderr flushing an extra empty line? loglist: List[str] = log.split('\n') # Strip the time field and test for i in range(len(reference)): From 451a22e41716b947471a3b4e15f80d0a60ea416e Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 03:38:01 +0100 Subject: [PATCH 07/12] Move things around and test log_once functionality --- README.md | 21 ++++--- src/opustrainer/logger.py | 9 ++- src/opustrainer/trainer.py | 2 +- tests/test_endtoend.py | 123 ++++++++++++++++++++++--------------- tests/test_logger.py | 50 +++++++-------- 5 files changed, 120 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index 7aef3aa..11ca5a8 100644 --- a/README.md +++ b/README.md @@ -20,26 +20,33 @@ pip install -e . ## Usage ```bash -% ./trainer.py --help -usage: trainer.py [-h] --config CONFIG [--temporary-directory TEMPORARY_DIR] [--state STATE_FILE] [--do-not-resume] [--sync] [trainer-command [arguments]] +% opustrainer-train --help +usage: opustrainer-train [-h] --config CONFIG [--state STATE] [--sync] [--temporary-directory TEMPORARY_DIRECTORY] [--do-not-resume] [--no-shuffle] [--log-level LOG_LEVEL] [--log-file LOG_FILE] ... Feeds marian tsv data for training. +positional arguments: + trainer Trainer program that gets fed the input. If empty it is read from config. + options: -h, --help show this help message and exit --config CONFIG, -c CONFIG YML configuration input. - --temporary-directory TEMPORARY_DIR, -t TEMPORARY_DIR + --state STATE, -s STATE + YML state file, defaults to ${CONFIG}.state. + --sync Do not shuffle async + --temporary-directory TEMPORARY_DIRECTORY, -T TEMPORARY_DIRECTORY Temporary dir, used for shuffling and tracking state - --state STATE_FILE Path to trainer state file which stores how much of - each dataset has been read. Defaults to ${CONFIG}.state - --sync Do not shuffle in the background --do-not-resume, -d Do not resume from the previous training state --no-shuffle, -n Do not shuffle, for debugging + --log-level LOG_LEVEL + Set log level. Available levels: DEBUG, INFO, WARNING, ERROR, CRITICAL. Default is INFO + --log-file LOG_FILE, -l LOG_FILE + Target location for logging. Always logs to stderr and optionally to a file. ``` Once you fix the paths in the configuration file, `train_config.yml` you can run a test case by doing: ```bash -./trainer.py -c train_config.yml /path/to/marian -c marian_config --any --other --flags +opustrainer-train -c train_config.yml /path/to/marian -c marian_config --any --other --flags ``` You can check resulting mixed file in `/tmp/test`. If your neural network trainer doesn't support training from `stdin`, you can use this tool to generate a training dataset and then disable data reordering or shuffling at your trainer implementation, as your training input should be balanced. diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index 136710a..5b3dac6 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -24,11 +24,16 @@ def log_once(msg: str, loglevel: str = "INFO") -> None: log(msg, loglevel) -def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO") -> None: +def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO", disable_stderr: bool=False) -> None: """Sets up the logger with the necessary settings. Outputs to both file and stderr""" loggingformat = '[%(asctime)s] [Trainer] [%(levelname)s] %(message)s' handlers: List[logging.StreamHandler[TextIO] | logging.StreamHandler[TextIOWrapper]] = [] - handlers.append(logging.StreamHandler(stream=stderr)) + # disable_stderr is to be used only when testing the logger + # When testing the logger directly, we don't want to write to stderr, because in order to read + # our stderr output, we have to use redirect_stderr, which however makes all other tests spit + # as it interferes with unittest' own redirect_stderr. How nice. + if not disable_stderr: + handlers.append(logging.StreamHandler(stream=stderr)) if outputfilename is not None: handlers.append(logging.FileHandler(filename=outputfilename)) logging.basicConfig(handlers=handlers, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') diff --git a/src/opustrainer/trainer.py b/src/opustrainer/trainer.py index f1081db..ac28f42 100755 --- a/src/opustrainer/trainer.py +++ b/src/opustrainer/trainer.py @@ -728,7 +728,7 @@ def main() -> None: parser.add_argument("--temporary-directory", '-T', default=None, type=str, help='Temporary dir, used for shuffling and tracking state') parser.add_argument("--do-not-resume", '-d', action="store_true", help='Do not resume from the previous training state') parser.add_argument("--no-shuffle", '-n', action="store_false", help='Do not shuffle, for debugging', dest="shuffle") - parser.add_argument("--log-level", type=str, default="INFO", help="Set log level. Available levels: DEBUG, INFO, WARNING, ERROR, CRITICAL") + parser.add_argument("--log-level", type=str, default="INFO", help="Set log level. Available levels: DEBUG, INFO, WARNING, ERROR, CRITICAL. Default is INFO") parser.add_argument("--log-file", '-l', type=str, default=None, help="Target location for logging. Always logs to stderr and optionally to a file.") parser.add_argument("trainer", type=str, nargs=argparse.REMAINDER, help="Trainer program that gets fed the input. If empty it is read from config.") diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index fee8a5a..edba55d 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -1,61 +1,82 @@ import sys import unittest import subprocess - +import tempfile +from typing import List class TestEndToEnd(unittest.TestCase): - '''Tests the pipeline end-to-end. Aimed to to test the parser.''' - def test_full_enzh(self): - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - output = process.stdout - reference: str = "" - with open('contrib/test-data/test_enzh_config.expected.out', 'r', encoding='utf-8') as reffile: - reference: str = "".join(reffile.readlines()) - self.assertEqual(output, reference) + '''Tests the pipeline end-to-end. Aimed to to test the parser.''' + def test_full_enzh(self): + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout + reference: str = "" + with open('contrib/test-data/test_enzh_config.expected.out', 'r', encoding='utf-8') as reffile: + reference: str = "".join(reffile.readlines()) + self.assertEqual(output, reference) + + def test_full_zhen(self): + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout + reference: str = "" + with open('contrib/test-data/test_zhen_config.expected.out', 'r', encoding='utf-8') as reffile: + reference: str = "".join(reffile.readlines()) + self.assertEqual(output, reference) - def test_full_zhen(self): - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - output = process.stdout - reference: str = "" - with open('contrib/test-data/test_zhen_config.expected.out', 'r', encoding='utf-8') as reffile: - reference: str = "".join(reffile.readlines()) - self.assertEqual(output, reference) + def test_prefix_augment(self): + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_prefix_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout + reference: str = "" + with open('contrib/test-data/test_zhen_config_prefix.expected.out', 'r', encoding='utf-8') as reffile: + reference: str = "".join(reffile.readlines()) + self.assertEqual(output, reference) - def test_prefix_augment(self): - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_zhen_prefix_config.yml', '-d', '--sync'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - output = process.stdout - reference: str = "" - with open('contrib/test-data/test_zhen_config_prefix.expected.out', 'r', encoding='utf-8') as reffile: - reference: str = "".join(reffile.readlines()) - self.assertEqual(output, reference) + def test_no_shuffle(self): + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout + reference: str = "" + with open('contrib/test-data/clean.enzh.10', 'r', encoding='utf-8') as reffile: + reference: str = "".join(reffile.readlines()) + # Since we read 100 lines at a time, we wrap. Often. + # Hence, for the test to pass we need to read the number of lines in the test file + reference_arr = reference.split('\n') + output_arr = output.split('\n') + for i in range(len(reference_arr)): + # Skip final empty newline + if reference_arr[i] != '': + self.assertEqual(output_arr[i], reference_arr[i]) - def test_no_shuffle(self): - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - output = process.stdout - reference: str = "" - with open('contrib/test-data/clean.enzh.10', 'r', encoding='utf-8') as reffile: - reference: str = "".join(reffile.readlines()) - # Since we read 100 lines at a time, we wrap. Often. - # Hence, for the test to pass we need to read the number of lines in the test file - reference_arr = reference.split('\n') - output_arr = output.split('\n') - for i in range(len(reference_arr)): - # Skip final empty newline - if reference_arr[i] != '': - self.assertEqual(output_arr[i], reference_arr[i]) + def test_advanced_config(self): + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_advanced_config.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout + reference: str = "" + with open('contrib/test-data/test_enzh_tags_advanced_config.expected.out', 'r', encoding='utf-8') as reffile: + reference: str = "".join(reffile.readlines()) + self.assertEqual(output, reference) - def test_advanced_config(self): - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_advanced_config.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - output = process.stdout - reference: str = "" - with open('contrib/test-data/test_enzh_tags_advanced_config.expected.out', 'r', encoding='utf-8') as reffile: - reference: str = "".join(reffile.readlines()) - self.assertEqual(output, reference) + def test_stage_config(self): + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_stage_config.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + output = process.stdout + reference: str = "" + with open('contrib/test-data/test_enzh_tags_stage_config.expected.out', 'r', encoding='utf-8') as reffile: + reference: str = "".join(reffile.readlines()) + self.assertEqual(output, reference) - def test_stage_config(self): - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_tags_stage_config.yml', '-d', '-n'], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - output = process.stdout - reference: str = "" - with open('contrib/test-data/test_enzh_tags_stage_config.expected.out', 'r', encoding='utf-8') as reffile: - reference: str = "".join(reffile.readlines()) - self.assertEqual(output, reference) + def test_log_file_and_stderr(self): + with tempfile.NamedTemporaryFile(suffix='.log', prefix="opustrainer") as tmpfile: + process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-l', tmpfile.name], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + log = process.stderr + # Check that the stderr output is the same the output written to the file + with open(tmpfile.name, 'r', encoding='utf-8') as logfile: + file_output: str = "".join(logfile.readlines()) + self.assertEqual(log, file_output) + # Check if the log is the same as the reference log. This is more complicated as we have a time field + with open('contrib/test-data/test_enzh_config_plain_expected.log', 'r', encoding='utf-8') as reffile: + reference: List[str] = reffile.readlines() + # Loglist has one extra `\n` compared to reference list, due to stderr flushing an extra empty line? + loglist: List[str] = log.split('\n') + # Strip the time field and test + for i in range(len(reference)): + ref = reference[i].split('[Trainer]')[1] + ref = ref.strip('\n') + logout = loglist[i].split('[Trainer]')[1] + self.assertEqual(logout, ref) diff --git a/tests/test_logger.py b/tests/test_logger.py index db30ddf..56b8293 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,29 +1,31 @@ -import sys import unittest -import subprocess import tempfile -from typing import List +from opustrainer import logger class TestLogger(unittest.TestCase): - '''Tests the logger using.''' - def test_file_and_stderr(self): - with tempfile.NamedTemporaryFile(suffix='.log', prefix="opustrainer") as tmpfile: - process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-l', tmpfile.name], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - log = process.stderr - # Check that the stderr output is the same the output written to the file - with open(tmpfile.name, 'r', encoding='utf-8') as logfile: - file_output: str = "".join(logfile.readlines()) - self.assertEqual(log, file_output) - # Check if the log is the same as the reference log. This is more complicated as we have a time field - with open('contrib/test-data/test_enzh_config_plain_expected.log', 'r', encoding='utf-8') as reffile: - reference: List[str] = reffile.readlines() - # Loglist has one extra `\n` compared to reference list, due to stderr flushing an extra empty line? - loglist: List[str] = log.split('\n') - # Strip the time field and test - for i in range(len(reference)): - ref = reference[i].split('[Trainer]')[1] - ref = ref.strip('\n') - logout = loglist[i].split('[Trainer]')[1] - self.assertEqual(logout, ref) - + '''Tests the logger using end-to-end config.''' + def test_log_once(self): + '''Tests the log_once functionality''' + with tempfile.NamedTemporaryFile(suffix='.log', prefix="logger") as tmpfile: + logger.setup_logger(outputfilename=tmpfile.name, disable_stderr=True) + logger.log("Test message") + logger.log_once("Once message") + logger.log_once("Once message") + logger.log("Test message") + logger.log_once("Once message") + logger.log_once("Once message2") + logger.log_once("Once message2") + logger.log("Final message") + logger.logging.shutdown() + with open(tmpfile.name, 'r', encoding='utf-8') as reffile: + line1 = reffile.readline().strip().split(' [Trainer] ')[1] + self.assertEqual(line1,"[INFO] Test message") + line2 = reffile.readline().strip().split(' [Trainer] ')[1] + self.assertEqual(line2,"[INFO] Once message") + line3 = reffile.readline().strip().split(' [Trainer] ')[1] + self.assertEqual(line3,"[INFO] Test message") + line4 = reffile.readline().strip().split(' [Trainer] ')[1] + self.assertEqual(line4,"[INFO] Once message2") + line5 = reffile.readline().strip().split(' [Trainer] ')[1] + self.assertEqual(line5,"[INFO] Final message") From 04937ee77502f1f5db4dbabc7a70136992189e6a Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 04:47:18 +0100 Subject: [PATCH 08/12] Make tests more proper and remove warnings package usage --- src/opustrainer/logger.py | 8 +++++-- src/opustrainer/modifiers/placeholders.py | 5 ++--- tests/test_endtoend.py | 4 ++-- tests/test_logger.py | 23 ++++++++++--------- tests/test_placeholders.py | 27 ++++++++++++++--------- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index 5b3dac6..c365485 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -32,9 +32,13 @@ def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO", disa # When testing the logger directly, we don't want to write to stderr, because in order to read # our stderr output, we have to use redirect_stderr, which however makes all other tests spit # as it interferes with unittest' own redirect_stderr. How nice. + # This happens even when assertLogs context capture is used. if not disable_stderr: handlers.append(logging.StreamHandler(stream=stderr)) if outputfilename is not None: handlers.append(logging.FileHandler(filename=outputfilename)) - logging.basicConfig(handlers=handlers, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S') - + # This is the only logger we'd ever use. However during testing, due to the context, logger can't be recreated, + # even if it has already been shutdown. This is why we use force=True to force recreation of logger so we can + # properly run our tests. Not the best solution, not sure if it's not prone to race conditions, but it is + # at the very least safe to use for the actual software running + logging.basicConfig(handlers=handlers, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S', force=True) diff --git a/src/opustrainer/modifiers/placeholders.py b/src/opustrainer/modifiers/placeholders.py index a4340df..45448cb 100644 --- a/src/opustrainer/modifiers/placeholders.py +++ b/src/opustrainer/modifiers/placeholders.py @@ -1,11 +1,10 @@ import random from operator import itemgetter from typing import Set, List, Tuple, Optional, Protocol, TypeVar, Iterable -from warnings import warn - from sacremoses import MosesDetokenizer from opustrainer.modifiers import Modifier +from opustrainer import logger T = TypeVar('T') @@ -342,4 +341,4 @@ def validate(self, context:List[Modifier]) -> None: inserted tags, which we don't want. So warn users about that if we notice it. """ if context[-1] != self: - warn('Tags modifier should to be the last modifier to be applied, as otherwise other modifiers might alter the inserted tags themselves.') + logger.log('Tags modifier should to be the last modifier to be applied, as otherwise other modifiers might alter the inserted tags themselves.', loglevel="WARNING") diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index edba55d..4167c84 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -66,8 +66,8 @@ def test_log_file_and_stderr(self): process = subprocess.run([sys.executable, '-m', 'opustrainer', '-c', 'contrib/test_enzh_config_plain.yml', '-d', '-l', tmpfile.name], stdout = subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") log = process.stderr # Check that the stderr output is the same the output written to the file - with open(tmpfile.name, 'r', encoding='utf-8') as logfile: - file_output: str = "".join(logfile.readlines()) + tmpfile.seek(0) + file_output: str = "".join([line.decode('utf-8') for line in tmpfile.readlines()]) self.assertEqual(log, file_output) # Check if the log is the same as the reference log. This is more complicated as we have a time field with open('contrib/test-data/test_enzh_config_plain_expected.log', 'r', encoding='utf-8') as reffile: diff --git a/tests/test_logger.py b/tests/test_logger.py index 56b8293..c67f410 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -18,14 +18,15 @@ def test_log_once(self): logger.log_once("Once message2") logger.log("Final message") logger.logging.shutdown() - with open(tmpfile.name, 'r', encoding='utf-8') as reffile: - line1 = reffile.readline().strip().split(' [Trainer] ')[1] - self.assertEqual(line1,"[INFO] Test message") - line2 = reffile.readline().strip().split(' [Trainer] ')[1] - self.assertEqual(line2,"[INFO] Once message") - line3 = reffile.readline().strip().split(' [Trainer] ')[1] - self.assertEqual(line3,"[INFO] Test message") - line4 = reffile.readline().strip().split(' [Trainer] ')[1] - self.assertEqual(line4,"[INFO] Once message2") - line5 = reffile.readline().strip().split(' [Trainer] ')[1] - self.assertEqual(line5,"[INFO] Final message") + # Now read + tmpfile.seek(0) + line1 = tmpfile.readline().decode('utf-8').strip().split(' [Trainer] ')[1] + self.assertEqual(line1,"[INFO] Test message") + line2 = tmpfile.readline().decode('utf-8').strip().split(' [Trainer] ')[1] + self.assertEqual(line2,"[INFO] Once message") + line3 = tmpfile.readline().decode('utf-8').strip().split(' [Trainer] ')[1] + self.assertEqual(line3,"[INFO] Test message") + line4 = tmpfile.readline().decode('utf-8').strip().split(' [Trainer] ')[1] + self.assertEqual(line4,"[INFO] Once message2") + line5 = tmpfile.readline().decode('utf-8').strip().split(' [Trainer] ')[1] + self.assertEqual(line5,"[INFO] Final message") diff --git a/tests/test_placeholders.py b/tests/test_placeholders.py index e38246d..5fdda3c 100644 --- a/tests/test_placeholders.py +++ b/tests/test_placeholders.py @@ -1,10 +1,12 @@ import random import unittest +import tempfile from textwrap import dedent from opustrainer.modifiers.placeholders import PlaceholderTagModifier from opustrainer.trainer import CurriculumLoader +from opustrainer import logger class TestTagger(unittest.TestCase): @@ -77,13 +79,18 @@ def test_tagger_zh_src_augment_replace(self): self.assertEqual(test, ref) def test_warn_if_tag_modifier_is_not_last(self): - with self.assertWarnsRegex(UserWarning, r'Tags modifier should to be the last modifier to be applied'): - loader = CurriculumLoader() - loader.load(dedent(""" - datasets: {} - stages: [] - seed: 1 - modifiers: - - Tags: 1.0 - - UpperCase: 1.0 - """)) + with tempfile.NamedTemporaryFile(suffix='.log', prefix="placeholder") as tmpfile: + logger.setup_logger(outputfilename=tmpfile.name, disable_stderr=True) + loader = CurriculumLoader() + loader.load(dedent(""" + datasets: {} + stages: [] + seed: 1 + modifiers: + - Tags: 1.0 + - UpperCase: 1.0 + """)) + logger.logging.shutdown() + tmpfile.seek(0) + warning = tmpfile.readline().decode('utf-8') + self.assertRegex(warning, r"Tags modifier should to be the last modifier to be applied") From c45fdc83ea550638d8e398eb04829e89362bcf18 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 04:52:51 +0100 Subject: [PATCH 09/12] Test the log level as well --- tests/test_placeholders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_placeholders.py b/tests/test_placeholders.py index 5fdda3c..ecd0b5f 100644 --- a/tests/test_placeholders.py +++ b/tests/test_placeholders.py @@ -93,4 +93,5 @@ def test_warn_if_tag_modifier_is_not_last(self): logger.logging.shutdown() tmpfile.seek(0) warning = tmpfile.readline().decode('utf-8') + self.assertRegex(warning, r"WARNING") self.assertRegex(warning, r"Tags modifier should to be the last modifier to be applied") From 80b3252ea8752df95c43a2d2ac8ac9531e64a6f5 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 12:28:56 +0100 Subject: [PATCH 10/12] Make it workable on older python --- src/opustrainer/logger.py | 31 +++++++++++++++++++++++-------- src/opustrainer/trainer.py | 2 +- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index c365485..4bb0676 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -1,20 +1,35 @@ from io import TextIOWrapper import logging from sys import stderr -from typing import List, TextIO +from typing import List, Dict, TextIO, Union from functools import lru_cache -def getLogLevel(name: str) -> int: +def _getLevelNamesMapping() -> Dict[str,int]: + '''getLevelNamesMapping only available in python 3.11+''' + if 'getLevelNamesMapping' in logging.__dict__: + return logging.getLevelNamesMapping() + else: + return {'CRITICAL': 50, + 'FATAL': 50, + 'ERROR': 40, + 'WARN': 30, + 'WARNING': 30, + 'INFO': 20, + 'DEBUG': 10, + 'NOTSET': 0} + +@lru_cache(None) +def get_log_level(name: str) -> int: """Incredibly, i can't find a function that will do this conversion, other than setLevel, but setLevel doesn't work for calling the different log level logs.""" - if name.upper() in logging.getLevelNamesMapping(): - return logging.getLevelNamesMapping()[name.upper()] + if name.upper() in _getLevelNamesMapping(): + return _getLevelNamesMapping()[name.upper()] else: logging.log(logging.WARNING, "unknown log level level used: " + name + " assuming warning...") return logging.WARNING def log(msg: str, loglevel: str = "INFO") -> None: - level = getLogLevel(loglevel) + level = get_log_level(loglevel) logging.log(level, msg) @@ -24,10 +39,10 @@ def log_once(msg: str, loglevel: str = "INFO") -> None: log(msg, loglevel) -def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO", disable_stderr: bool=False) -> None: +def setup_logger(outputfilename: Union[str, None] = None, loglevel: str = "INFO", disable_stderr: bool=False) -> None: """Sets up the logger with the necessary settings. Outputs to both file and stderr""" loggingformat = '[%(asctime)s] [Trainer] [%(levelname)s] %(message)s' - handlers: List[logging.StreamHandler[TextIO] | logging.StreamHandler[TextIOWrapper]] = [] + handlers: List[Union[logging.StreamHandler[TextIO], logging.StreamHandler[TextIOWrapper]]] = [] # disable_stderr is to be used only when testing the logger # When testing the logger directly, we don't want to write to stderr, because in order to read # our stderr output, we have to use redirect_stderr, which however makes all other tests spit @@ -41,4 +56,4 @@ def setup_logger(outputfilename: str | None = None, loglevel: str = "INFO", disa # even if it has already been shutdown. This is why we use force=True to force recreation of logger so we can # properly run our tests. Not the best solution, not sure if it's not prone to race conditions, but it is # at the very least safe to use for the actual software running - logging.basicConfig(handlers=handlers, encoding='utf-8', level=getLogLevel(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S', force=True) + logging.basicConfig(handlers=handlers, encoding='utf-8', level=get_log_level(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S', force=True) diff --git a/src/opustrainer/trainer.py b/src/opustrainer/trainer.py index ac28f42..0698e1f 100755 --- a/src/opustrainer/trainer.py +++ b/src/opustrainer/trainer.py @@ -209,7 +209,7 @@ def __next__(self): # assert that the line is well formed, meaning non of the fields is the empty string # If not, try to get a new line from the corpus if any(field == '' for field in line.rstrip('\r\n').split('\t')): - print(f"[Trainer] Empty field in {self.dataset.name} line:\"{line}\", skipping...") + logger.log_once(f"[Trainer] Empty field in {self.dataset.name} line:\"{line}\", skipping...", loglevel="WARNING") continue return line From b9c854590d97fa1f809eb0b285687b382e0a9a5c Mon Sep 17 00:00:00 2001 From: Jelmer van der Linde Date: Tue, 4 Jul 2023 12:53:36 +0100 Subject: [PATCH 11/12] Backport to 3.8 --- src/opustrainer/logger.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index 4bb0676..5fda68c 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -1,7 +1,7 @@ from io import TextIOWrapper import logging -from sys import stderr -from typing import List, Dict, TextIO, Union +from sys import stderr, version_info +from typing import List, Dict, TextIO, Union, Optional from functools import lru_cache def _getLevelNamesMapping() -> Dict[str,int]: @@ -25,7 +25,7 @@ def get_log_level(name: str) -> int: if name.upper() in _getLevelNamesMapping(): return _getLevelNamesMapping()[name.upper()] else: - logging.log(logging.WARNING, "unknown log level level used: " + name + " assuming warning...") + logging.log(logging.WARNING, f"unknown log level level used: {name} assuming warning...") return logging.WARNING def log(msg: str, loglevel: str = "INFO") -> None: @@ -39,7 +39,7 @@ def log_once(msg: str, loglevel: str = "INFO") -> None: log(msg, loglevel) -def setup_logger(outputfilename: Union[str, None] = None, loglevel: str = "INFO", disable_stderr: bool=False) -> None: +def setup_logger(outputfilename: Optional[str] = None, loglevel: str = "INFO", disable_stderr: bool=False) -> None: """Sets up the logger with the necessary settings. Outputs to both file and stderr""" loggingformat = '[%(asctime)s] [Trainer] [%(levelname)s] %(message)s' handlers: List[Union[logging.StreamHandler[TextIO], logging.StreamHandler[TextIOWrapper]]] = [] @@ -52,8 +52,14 @@ def setup_logger(outputfilename: Union[str, None] = None, loglevel: str = "INFO" handlers.append(logging.StreamHandler(stream=stderr)) if outputfilename is not None: handlers.append(logging.FileHandler(filename=outputfilename)) + + # Python 3.9 introduced an encoding argument + if version_info[:2] >= (3,9): + kwargs = {'encoding': 'utf-8'} + else: + kwargs = {} # This is the only logger we'd ever use. However during testing, due to the context, logger can't be recreated, # even if it has already been shutdown. This is why we use force=True to force recreation of logger so we can # properly run our tests. Not the best solution, not sure if it's not prone to race conditions, but it is # at the very least safe to use for the actual software running - logging.basicConfig(handlers=handlers, encoding='utf-8', level=get_log_level(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S', force=True) + logging.basicConfig(handlers=handlers, level=get_log_level(loglevel), format=loggingformat, datefmt='%Y-%m-%d %H:%M:%S', force=True, **kwargs) From 6d0ff0445c8d69dca045331a9ddd4a0215a9bb19 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Tue, 4 Jul 2023 13:36:13 +0100 Subject: [PATCH 12/12] Small formatting fix --- src/opustrainer/logger.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/opustrainer/logger.py b/src/opustrainer/logger.py index 5fda68c..328bc14 100644 --- a/src/opustrainer/logger.py +++ b/src/opustrainer/logger.py @@ -10,13 +10,13 @@ def _getLevelNamesMapping() -> Dict[str,int]: return logging.getLevelNamesMapping() else: return {'CRITICAL': 50, - 'FATAL': 50, - 'ERROR': 40, - 'WARN': 30, - 'WARNING': 30, - 'INFO': 20, - 'DEBUG': 10, - 'NOTSET': 0} + 'FATAL': 50, + 'ERROR': 40, + 'WARN': 30, + 'WARNING': 30, + 'INFO': 20, + 'DEBUG': 10, + 'NOTSET': 0} @lru_cache(None) def get_log_level(name: str) -> int: