diff --git a/silnlp/common/corpus.py b/silnlp/common/corpus.py
index 440ba310..6606ee37 100644
--- a/silnlp/common/corpus.py
+++ b/silnlp/common/corpus.py
@@ -162,6 +162,26 @@ def exclude_books(corpus: pd.DataFrame, books: Set[int]) -> pd.DataFrame:
     return corpus[corpus.apply(lambda r: r["vref"].book_num not in books, axis=1)].copy()
 
 
+def include_chapters(corpus: pd.DataFrame, books: dict) -> pd.DataFrame:
+    return corpus[
+        corpus.apply(
+            lambda r: r["vref"].book_num in books
+            and (len(books[r["vref"].book_num]) == 0 or r["vref"].chapter_num in books[r["vref"].book_num]),
+            axis=1,
+        )
+    ].copy()
+
+
+def exclude_chapters(corpus: pd.DataFrame, books: dict) -> pd.DataFrame:
+    return corpus[
+        corpus.apply(
+            lambda r: r["vref"].book_num not in books
+            or (len(books[r["vref"].book_num]) > 0 and r["vref"].chapter_num not in books[r["vref"].book_num]),
+            axis=1,
+        )
+    ].copy()
+
+
 def get_terms_metadata_path(list_name: str, mt_terms_dir: Path = SIL_NLP_ENV.mt_terms_dir) -> Path:
     md_path = SIL_NLP_ENV.assets_dir / f"{list_name}-metadata.txt"
     if md_path.is_file():
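Review note: both helpers assume the chapter-dict convention used by machine.scripture.get_chapters, where book numbers map to chapter lists and an empty list selects the whole book. A minimal sketch with a hypothetical three-verse corpus, assuming this branch is applied:

```python
import pandas as pd
from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef

from silnlp.common.corpus import exclude_chapters, include_chapters

refs = ["MAT 1:1", "MAT 2:1", "MRK 1:1"]
corpus = pd.DataFrame({"vref": [VerseRef.from_string(r, ORIGINAL_VERSIFICATION) for r in refs]})

books = {40: [1], 41: []}  # MAT chapter 1 only; empty list = all of MRK
print([str(v) for v in include_chapters(corpus, books)["vref"]])  # ['MAT 1:1', 'MRK 1:1']
print([str(v) for v in exclude_chapters(corpus, books)["vref"]])  # ['MAT 2:1']
```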
RuntimeError(f"Can't find file {book_path} for book {book}") else: LOGGER.info(f"Found the file {book_path} for book {book}") - self.translate_usfm(book_path, output_path, src_iso, trg_iso, stylesheet, include_inline_elements) + + self.translate_usfm( + book_path, output_path, src_iso, trg_iso, chapters, trg_project, stylesheet, include_inline_elements + ) def translate_usfm( self, @@ -221,6 +276,8 @@ def translate_usfm( trg_file_path: Path, src_iso: str, trg_iso: str, + chapters: List[int] = [], + trg_project_path: str = "", stylesheet: dict = usfm.relaxed_stylesheet, include_inline_elements: bool = False, ) -> None: @@ -240,11 +297,50 @@ def translate_usfm( segments = collect_segments(book, doc) - sentences = (s.text.strip() for s in segments) - vrefs = (s.ref for s in segments) + sentences = [s.text.strip() for s in segments] + vrefs = [s.ref for s in segments] LOGGER.info(f"File {src_file_path} parsed correctly.") - translations = list(self.translate(sentences, src_iso, trg_iso, vrefs)) + # Translate select chapters + if len(chapters) > 0: + idxs_to_translate = [] + sentences_to_translate = [] + vrefs_to_translate = [] + for i in range(len(sentences)): + if vrefs[i].chapter_num in chapters: + idxs_to_translate.append(i) + sentences_to_translate.append(sentences[i]) + vrefs_to_translate.append(vrefs[i]) + + partial_translation = list(self.translate(sentences_to_translate, src_iso, trg_iso, vrefs_to_translate)) + + # Get translation from pre-existing target project to fill in translation + if trg_project_path != "": + trg_project_book_path = get_book_path(trg_project_path, book) + if trg_project_book_path.exists(): + with trg_project_book_path.open(mode="r", encoding="utf-8-sig") as book_file: + trg_doc: List[sfm.Element] = list( + usfm.parser(book_file, stylesheet=stylesheet, canonicalise_footnotes=False) + ) + if not include_inline_elements: + remove_inline_elements(trg_doc) + trg_segments = collect_segments(book, trg_doc) + trg_sentences = [s.text.strip() for s in trg_segments] + trg_vrefs = [s.ref for s in trg_segments] + + translations = insert_translation_into_trg_sentences( + partial_translation, vrefs_to_translate, trg_sentences, trg_vrefs, chapters + ) + update_segments(trg_segments, translations) + with trg_file_path.open(mode="w", encoding="utf-8", newline="\n") as output_file: + output_file.write(sfm.generate(trg_doc)) + return + + translations = [""] * len(sentences) + for i, idx in enumerate(idxs_to_translate): + translations[idx] = partial_translation[i] + else: + translations = list(self.translate(sentences, src_iso, trg_iso, vrefs)) update_segments(segments, translations) diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index cfa8ae45..25653dbd 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -10,13 +10,13 @@ from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast import pandas as pd -from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books +from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_chapters from tqdm import tqdm from ..alignment.config import get_aligner_name from ..alignment.utils import add_alignment_scores from ..common.corpus import ( - exclude_books, + exclude_chapters, filter_parallel_corpus, get_mt_corpus_path, get_scripture_parallel_corpus, @@ -26,7 +26,7 @@ get_terms_glosses_path, get_terms_list, get_terms_renderings_path, - include_books, + include_chapters, load_corpus, split_corpus, split_parallel_corpus, @@ -98,8 +98,8 @@ class CorpusPair: disjoint_test: bool 
diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py
index cfa8ae45..25653dbd 100644
--- a/silnlp/nmt/config.py
+++ b/silnlp/nmt/config.py
@@ -10,13 +10,13 @@
 from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
 
 import pandas as pd
-from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books
+from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_chapters
 from tqdm import tqdm
 
 from ..alignment.config import get_aligner_name
 from ..alignment.utils import add_alignment_scores
 from ..common.corpus import (
-    exclude_books,
+    exclude_chapters,
     filter_parallel_corpus,
     get_mt_corpus_path,
     get_scripture_parallel_corpus,
@@ -26,7 +26,7 @@
     get_terms_glosses_path,
     get_terms_list,
     get_terms_renderings_path,
-    include_books,
+    include_chapters,
     load_corpus,
     split_corpus,
     split_parallel_corpus,
@@ -98,8 +98,8 @@ class CorpusPair:
     disjoint_test: bool
     disjoint_val: bool
     score_threshold: float
-    corpus_books: Set[int]
-    test_books: Set[int]
+    corpus_books: Dict[int, List[int]]
+    test_books: Dict[int, List[int]]
     use_test_set_from: str
     src_terms_files: List[DataFile]
     trg_terms_files: List[DataFile]
@@ -198,8 +198,10 @@ def parse_corpus_pairs(corpus_pairs: List[dict]) -> List[CorpusPair]:
            pair["disjoint_val"] = False
        disjoint_val: bool = pair["disjoint_val"]
        score_threshold: float = pair.get("score_threshold", 0.0)
-        corpus_books = get_books(pair.get("corpus_books", []))
-        test_books = get_books(pair.get("test_books", []))
+        corpus_books_string = pair.get("corpus_books", "")
+        corpus_books = get_chapters(corpus_books_string) if len(corpus_books_string) > 0 else {}
+        test_books_string = pair.get("test_books", "")
+        test_books = get_chapters(test_books_string) if len(test_books_string) > 0 else {}
        use_test_set_from: str = pair.get("use_test_set_from", "")
 
        src_terms_files = get_terms_files(src_files) if is_set(type, DataFileType.TRAIN) else []
@@ -344,6 +346,7 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
        self.has_scripture_data = False
        self._iso_pairs: Dict[Tuple[str, str], IsoPairInfo] = {}
        self.src_projects: Set[str] = set()
+        self.trg_projects: Set[str] = set()
        for corpus_pair in self.corpus_pairs:
            pair_src_isos = {sf.iso for sf in corpus_pair.src_files}
            pair_trg_isos = {tf.iso for tf in corpus_pair.trg_files}
@@ -362,6 +365,7 @@
                self.src_file_paths.update(sf.path for sf in corpus_pair.src_terms_files)
                self.trg_file_paths.update(tf.path for tf in corpus_pair.trg_terms_files)
                self.src_projects.update(sf.project for sf in corpus_pair.src_files)
+                self.trg_projects.update(sf.project for sf in corpus_pair.trg_files)
                if terms_config["include_glosses"]:
                    if "en" in pair_src_isos:
                        self.src_file_paths.update(get_terms_glosses_file_paths(corpus_pair.src_terms_files))
@@ -595,11 +599,11 @@
                corpus["source"] = [self._noise(pair.src_noise, x) for x in corpus["source"]]
 
            if len(pair.corpus_books) > 0:
-                cur_train = include_books(corpus, pair.corpus_books)
-                if len(pair.corpus_books.intersection(pair.test_books)) > 0:
-                    cur_train = exclude_books(cur_train, pair.test_books)
+                cur_train = include_chapters(corpus, pair.corpus_books)
+                if len(pair.test_books) > 0:
+                    cur_train = exclude_chapters(cur_train, pair.test_books)
            elif len(pair.test_books) > 0:
-                cur_train = exclude_books(corpus, pair.test_books)
+                cur_train = exclude_chapters(corpus, pair.test_books)
            else:
                cur_train = corpus
@@ -622,7 +626,7 @@
                test_indices = set(random.sample(indices, min(split_size, len(indices))))
 
            if len(pair.test_books) > 0:
-                cur_test = include_books(corpus, pair.test_books)
+                cur_test = include_chapters(corpus, pair.test_books)
                if test_size > 0:
                    _, cur_test = split_parallel_corpus(
                        cur_test,
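Review note: corpus_books and test_books now take the same chapter-selection strings as the --books flag below, instead of plain book lists. A sketch of the parse; the exact return values of machine.scripture.get_chapters shown in the comments are assumptions:

```python
from machine.scripture import get_chapters

# Book numbers map to chapter lists; an empty list keeps every chapter.
print(get_chapters("MAT1,2,3,5-11"))  # {40: [1, 2, 3, 5, 6, 7, 8, 9, 10, 11]}
print(get_chapters("GEN;EXO4"))       # {1: [], 2: [4]}
```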
diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py
index 9cdba62f..7125d8df 100644
--- a/silnlp/nmt/test.py
+++ b/silnlp/nmt/test.py
@@ -7,7 +7,7 @@
 from typing import IO, Dict, List, Optional, Set, TextIO, Tuple
 
 import sacrebleu
-from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, book_number_to_id, get_books
+from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, book_number_to_id, get_chapters
 from sacrebleu.metrics import BLEU, BLEUScore
 
 from ..common.environment import SIL_NLP_ENV
@@ -151,7 +151,7 @@ def process_individual_books(
     ref_file_paths: List[Path],
     vref_file_path: Path,
     select_rand_ref_line: bool,
-    books: Set[int],
+    books: Dict[int, List[int]],
 ) -> Dict[str, Tuple[List[str], List[List[str]]]]:
     # Output data structure
     book_dict: Dict[str, Tuple[List[str], List[List[str]]]] = {}
@@ -210,7 +210,7 @@ def load_test_data(
     output_file_name: str,
     ref_projects: Set[str],
     config: Config,
-    books: Set[int],
+    books: Dict[int, List[int]],
     by_book: bool,
 ) -> Tuple[List[str], List[List[str]], Dict[str, Tuple[List[str], List[List[str]]]]]:
     sys: List[str] = []
@@ -336,7 +336,7 @@ def test_checkpoint(
     checkpoint_type: CheckpointType,
     step: int,
     scorers: Set[str],
-    books: Set[int],
+    books: Dict[int, List[int]],
 ) -> List[PairScore]:
     config.set_seed()
     vref_file_names: List[str] = []
@@ -344,7 +344,7 @@
     translation_file_names: List[str] = []
     refs_patterns: List[str] = []
     translation_detok_file_names: List[str] = []
-    suffix_str = "_".join(map(lambda n: book_number_to_id(n), sorted(books)))
+    suffix_str = "_".join(map(lambda n: book_number_to_id(n), sorted(books.keys())))
     if len(suffix_str) > 0:
         suffix_str += "-"
     suffix_str += "avg" if step == -1 else str(step)
@@ -479,7 +479,7 @@ def test(
         LOGGER.info("No test dataset.")
         return
 
-    books_nums = get_books(books)
+    books_nums = get_chapters(";".join(books))
 
     if len(scorers) == 0:
         scorers.add("bleu")
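Review note: the test CLI still accepts multiple --books tokens (they are now joined with ';' before parsing), and the per-run file suffix is built from the parsed dict's keys. A small sketch of that suffix logic, using a hypothetical parse of "MAT1-4;MRK":

```python
from machine.scripture import book_number_to_id

# Assumed parse of "MAT1-4;MRK": chapters 1-4 of Matthew plus all of Mark.
books = {40: [1, 2, 3, 4], 41: []}
suffix_str = "_".join(map(book_number_to_id, sorted(books.keys())))
assert suffix_str == "MAT_MRK"  # chapter selections do not change the suffix
```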
f"{book_file_name_digits(book_num)}{book}.SFM" try: LOGGER.info(f"Translating {book} ...") - translator.translate_book(src_project, book, output_path, trg_iso, include_inline_elements) + if ( + trg_project is not None and len(book_nums[book_num]) > 0 + ): # Pass target project to fill in missing chapters if only some are being translated + output_path = output_dir_trg_project / f"{book_file_name_digits(book_num)}{book}.SFM" + translator.translate_book( + src_project, + book, + output_path, + trg_iso, + book_nums[book_num], + trg_project, + include_inline_elements, + ) + else: + output_path = output_dir / f"{book_file_name_digits(book_num)}{book}.SFM" + translator.translate_book( + src_project, + book, + output_path, + trg_iso, + book_nums[book_num], + include_inline_elements=include_inline_elements, + ) except Exception as e: error_str = " ".join([str(s) for s in e.args]) LOGGER.error(f"Was not able to translate {book}. Error: {error_str}") @@ -216,7 +261,17 @@ def main() -> None: parser.add_argument("--end-seq", default=None, type=int, help="Ending file sequence #") parser.add_argument("--src-project", default=None, type=str, help="The source project to translate") parser.add_argument( - "--books", metavar="books", nargs="+", default=[], help="The books to translate; e.g., 'NT', 'OT', 'GEN,EXO'" + "--trg-project", + default=None, + type=str, + help="The target project to use as the output for chapters that aren't translated", + ) + parser.add_argument( + "--books", + metavar="books", + nargs="+", + default=[], + help="The books to translate; e.g., 'NT', 'OT', 'GEN,EXO', can also select chapters; e.g., 'MAT-REV;-LUK10-30', 'MAT1,2,3,5-11'", ) parser.add_argument("--src-iso", default=None, type=str, help="The source language (iso code) to translate from") parser.add_argument("--trg-iso", default=None, type=str, help="The target language (iso code) to translate to") @@ -255,8 +310,6 @@ def main() -> None: if args.memory_growth: enable_memory_growth() - SIL_NLP_ENV.copy_pt_project_from_bucket(args.src_project) - translator = TranslationTask( name=args.experiment, checkpoint=args.checkpoint, @@ -265,9 +318,13 @@ def main() -> None: if len(args.books) > 0: if args.debug: - show_attrs(cli_args=args, actions=[f"Will attempt to translate books {args.books} into {args.trg_iso}"]) + show_attrs( + cli_args=args, actions=[f"Will attempt to translate books {';'.join(args.books)} into {args.trg_iso}"] + ) exit() - translator.translate_books(args.books, args.src_project, args.trg_iso, args.include_inline_elements) + translator.translate_books( + ";".join(args.books), args.src_project, args.trg_project, args.trg_iso, args.include_inline_elements + ) elif args.src_prefix is not None: if args.debug: show_attrs(