
Commit

Merge branch 'master' into #238_upgrade_poetry
mshannon-sil authored Dec 1, 2023
2 parents d99a436 + bcba32d commit caa5c03
Showing 5 changed files with 210 additions and 33 deletions.
20 changes: 20 additions & 0 deletions silnlp/common/corpus.py
@@ -162,6 +162,26 @@ def exclude_books(corpus: pd.DataFrame, books: Set[int]) -> pd.DataFrame:
return corpus[corpus.apply(lambda r: r["vref"].book_num not in books, axis=1)].copy()


def include_chapters(corpus: pd.DataFrame, books: dict) -> pd.DataFrame:
return corpus[
corpus.apply(
lambda r: r["vref"].book_num in books
and (len(books[r["vref"].book_num]) == 0 or r["vref"].chapter_num in books[r["vref"].book_num]),
axis=1,
)
].copy()


def exclude_chapters(corpus: pd.DataFrame, books: dict) -> pd.DataFrame:
return corpus[
corpus.apply(
lambda r: r["vref"].book_num not in books
or (len(books[r["vref"].book_num]) > 0 and r["vref"].chapter_num not in books[r["vref"].book_num]),
axis=1,
)
].copy()
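
As an illustration (not part of the commit), here is a minimal sketch of how the two new filters behave. It assumes silnlp and pandas are importable; SimpleNamespace stands in for the VerseRef objects in the "vref" column, which expose .book_num and .chapter_num.

from types import SimpleNamespace

import pandas as pd

from silnlp.common.corpus import exclude_chapters, include_chapters

rows = [
    SimpleNamespace(book_num=40, chapter_num=1),  # Matthew 1
    SimpleNamespace(book_num=40, chapter_num=2),  # Matthew 2
    SimpleNamespace(book_num=41, chapter_num=5),  # Mark 5
]
corpus = pd.DataFrame({"vref": rows, "text": ["a", "b", "c"]})

# An empty chapter list selects the whole book: {40: [1]} is Matthew 1 only,
# {41: []} is all of Mark.
print(include_chapters(corpus, {40: [1], 41: []})["text"].tolist())  # ['a', 'c']
print(exclude_chapters(corpus, {40: [1]})["text"].tolist())          # ['b', 'c']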


def get_terms_metadata_path(list_name: str, mt_terms_dir: Path = SIL_NLP_ENV.mt_terms_dir) -> Path:
md_path = SIL_NLP_ENV.assets_dir / f"{list_name}-metadata.txt"
if md_path.is_file():
106 changes: 101 additions & 5 deletions silnlp/common/translator.py
@@ -190,6 +190,50 @@ def remove_inline_elements(doc: List[sfm.Element]) -> None:
remove_inline_elements_from_element(root)


def insert_translation_into_trg_sentences(
sentences: List[str],
vrefs: List[VerseRef],
trg_sentences: List[str],
trg_vrefs: List[VerseRef],
chapters: List[int],
) -> List[str]:
ret = [""] * len(trg_sentences)
translation_idx = 0
for i in range(len(trg_sentences)):
if trg_vrefs[i].chapter_num not in chapters:
ret[i] = trg_sentences[i]
continue
# Skip over rest of verse since the whole verse is put into the first entry
if (
i > 0
and trg_vrefs[i].chapter_num == trg_vrefs[i - 1].chapter_num
and trg_vrefs[i].verse_num == trg_vrefs[i - 1].verse_num
):
continue
# If translation_idx gets behind, catch up
while translation_idx < len(sentences) and (
trg_vrefs[i].chapter_num > vrefs[translation_idx].chapter_num
or (
trg_vrefs[i].chapter_num == vrefs[translation_idx].chapter_num
and trg_vrefs[i].verse_num > vrefs[translation_idx].verse_num
)
):
translation_idx += 1

# Put all parts of the translated verse into the first entry for that verse
while (
translation_idx < len(sentences)
and vrefs[translation_idx].chapter_num == trg_vrefs[i].chapter_num
and vrefs[translation_idx].verse_num == trg_vrefs[i].verse_num
):
if ret[i] != "":
ret[i] += " "
ret[i] += sentences[translation_idx]
translation_idx += 1

return ret
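
For illustration (not part of the commit), a sketch of the merge semantics. It assumes silnlp is importable; SimpleNamespace stands in for VerseRef, which only needs .chapter_num and .verse_num here.

from types import SimpleNamespace

from silnlp.common.translator import insert_translation_into_trg_sentences

def ref(c: int, v: int) -> SimpleNamespace:
    return SimpleNamespace(chapter_num=c, verse_num=v)

merged = insert_translation_into_trg_sentences(
    sentences=["new 1:1a", "new 1:1b"],  # two translated parts of verse 1:1
    vrefs=[ref(1, 1), ref(1, 1)],
    trg_sentences=["old 1:1 part 1", "old 1:1 part 2", "old 2:1"],
    trg_vrefs=[ref(1, 1), ref(1, 1), ref(2, 1)],
    chapters=[1],
)
# merged == ["new 1:1a new 1:1b", "", "old 2:1"]: all parts of a translated
# verse land in that verse's first entry, later entries for the same verse
# are emptied, and chapters outside the selection keep the existing text.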


class Translator(ABC):
@abstractmethod
def translate(
@@ -201,26 +245,39 @@ def translate_text(self, src_file_path: Path, trg_file_path: Path, src_iso: str,
write_corpus(trg_file_path, self.translate(load_corpus(src_file_path), src_iso, trg_iso))

def translate_book(
-        self, src_project: str, book: str, output_path: Path, trg_iso: str, include_inline_elements: bool = False
+        self,
+        src_project: str,
+        book: str,
+        output_path: Path,
+        trg_iso: str,
+        chapters: List[int] = [],
+        trg_project: str = "",
+        include_inline_elements: bool = False,
) -> None:
src_project_dir = get_project_dir(src_project)
with (src_project_dir / "Settings.xml").open("rb") as settings_file:
settings_tree = etree.parse(settings_file)
src_iso = get_iso(settings_tree)
book_path = get_book_path(src_project, book)
stylesheet = get_stylesheet(src_project_dir)

if not book_path.is_file():
raise RuntimeError(f"Can't find file {book_path} for book {book}")
else:
LOGGER.info(f"Found the file {book_path} for book {book}")
-            self.translate_usfm(book_path, output_path, src_iso, trg_iso, stylesheet, include_inline_elements)

+        self.translate_usfm(
+            book_path, output_path, src_iso, trg_iso, chapters, trg_project, stylesheet, include_inline_elements
+        )

def translate_usfm(
self,
src_file_path: Path,
trg_file_path: Path,
src_iso: str,
trg_iso: str,
chapters: List[int] = [],
trg_project_path: str = "",
stylesheet: dict = usfm.relaxed_stylesheet,
include_inline_elements: bool = False,
) -> None:
@@ -240,11 +297,50 @@ def translate_usfm(

segments = collect_segments(book, doc)

-        sentences = (s.text.strip() for s in segments)
-        vrefs = (s.ref for s in segments)
+        sentences = [s.text.strip() for s in segments]
+        vrefs = [s.ref for s in segments]
LOGGER.info(f"File {src_file_path} parsed correctly.")

-        translations = list(self.translate(sentences, src_iso, trg_iso, vrefs))
# Translate select chapters
if len(chapters) > 0:
idxs_to_translate = []
sentences_to_translate = []
vrefs_to_translate = []
for i in range(len(sentences)):
if vrefs[i].chapter_num in chapters:
idxs_to_translate.append(i)
sentences_to_translate.append(sentences[i])
vrefs_to_translate.append(vrefs[i])

partial_translation = list(self.translate(sentences_to_translate, src_iso, trg_iso, vrefs_to_translate))

# Get translation from pre-existing target project to fill in translation
if trg_project_path != "":
trg_project_book_path = get_book_path(trg_project_path, book)
if trg_project_book_path.exists():
with trg_project_book_path.open(mode="r", encoding="utf-8-sig") as book_file:
trg_doc: List[sfm.Element] = list(
usfm.parser(book_file, stylesheet=stylesheet, canonicalise_footnotes=False)
)
if not include_inline_elements:
remove_inline_elements(trg_doc)
trg_segments = collect_segments(book, trg_doc)
trg_sentences = [s.text.strip() for s in trg_segments]
trg_vrefs = [s.ref for s in trg_segments]

translations = insert_translation_into_trg_sentences(
partial_translation, vrefs_to_translate, trg_sentences, trg_vrefs, chapters
)
update_segments(trg_segments, translations)
with trg_file_path.open(mode="w", encoding="utf-8", newline="\n") as output_file:
output_file.write(sfm.generate(trg_doc))
return

translations = [""] * len(sentences)
for i, idx in enumerate(idxs_to_translate):
translations[idx] = partial_translation[i]
else:
translations = list(self.translate(sentences, src_iso, trg_iso, vrefs))

update_segments(segments, translations)

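
For illustration (not part of the commit), a sketch of driving the new parameters end to end. The abstract translate signature is abbreviated in this diff, so the stub below assumes (sentences, src_iso, trg_iso, vrefs), matching the call sites above; the project names and output path are hypothetical.

from pathlib import Path
from typing import Iterable, List, Optional

from machine.scripture import VerseRef

from silnlp.common.translator import Translator

class UpperCaseTranslator(Translator):
    # Stand-in "model": upper-cases each source segment.
    def translate(
        self,
        sentences: Iterable[str],
        src_iso: str,
        trg_iso: str,
        vrefs: Optional[List[VerseRef]] = None,
    ) -> Iterable[str]:
        return (s.upper() for s in sentences)

# Translate only Matthew 1-3; the remaining chapters are filled in from the
# pre-existing target project.
UpperCaseTranslator().translate_book(
    "MySrcProject",
    "MAT",
    Path("MAT_draft.SFM"),
    trg_iso="xyz",
    chapters=[1, 2, 3],
    trg_project="MyTrgProject",
)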
28 changes: 16 additions & 12 deletions silnlp/nmt/config.py
@@ -10,13 +10,13 @@
from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast

import pandas as pd
-from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books
+from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_chapters
from tqdm import tqdm

from ..alignment.config import get_aligner_name
from ..alignment.utils import add_alignment_scores
from ..common.corpus import (
-    exclude_books,
+    exclude_chapters,
filter_parallel_corpus,
get_mt_corpus_path,
get_scripture_parallel_corpus,
@@ -26,7 +26,7 @@
get_terms_glosses_path,
get_terms_list,
get_terms_renderings_path,
-    include_books,
+    include_chapters,
load_corpus,
split_corpus,
split_parallel_corpus,
@@ -98,8 +98,8 @@ class CorpusPair:
disjoint_test: bool
disjoint_val: bool
score_threshold: float
-    corpus_books: Set[int]
-    test_books: Set[int]
+    corpus_books: Dict[int, List[int]]
+    test_books: Dict[int, List[int]]
use_test_set_from: str
src_terms_files: List[DataFile]
trg_terms_files: List[DataFile]
@@ -198,8 +198,10 @@ def parse_corpus_pairs(corpus_pairs: List[dict]) -> List[CorpusPair]:
pair["disjoint_val"] = False
disjoint_val: bool = pair["disjoint_val"]
score_threshold: float = pair.get("score_threshold", 0.0)
-        corpus_books = get_books(pair.get("corpus_books", []))
-        test_books = get_books(pair.get("test_books", []))
+        corpus_books_string = pair.get("corpus_books", "")
+        corpus_books = get_chapters(corpus_books_string) if len(corpus_books_string) > 0 else {}
+        test_books_string = pair.get("test_books", "")
+        test_books = get_chapters(test_books_string) if len(test_books_string) > 0 else {}
use_test_set_from: str = pair.get("use_test_set_from", "")

src_terms_files = get_terms_files(src_files) if is_set(type, DataFileType.TRAIN) else []
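
Illustrative note (not part of the commit): corpus_books and test_books are now parsed by machine.scripture.get_chapters into a per-book chapter dict. The spec string below is an assumed example of the syntax; the authoritative format belongs to the machine library.

from machine.scripture import get_chapters

books = get_chapters("MAT1-2;MRK")
# Expected shape, matching how include_chapters/exclude_chapters consume it:
# {40: [1, 2], 41: []}  -- Matthew 1-2 plus all of Mark; an empty list
# means every chapter of that book.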
@@ -344,6 +346,7 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
self.has_scripture_data = False
self._iso_pairs: Dict[Tuple[str, str], IsoPairInfo] = {}
self.src_projects: Set[str] = set()
self.trg_projects: Set[str] = set()
for corpus_pair in self.corpus_pairs:
pair_src_isos = {sf.iso for sf in corpus_pair.src_files}
pair_trg_isos = {tf.iso for tf in corpus_pair.trg_files}
@@ -362,6 +365,7 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
self.src_file_paths.update(sf.path for sf in corpus_pair.src_terms_files)
self.trg_file_paths.update(tf.path for tf in corpus_pair.trg_terms_files)
self.src_projects.update(sf.project for sf in corpus_pair.src_files)
self.trg_projects.update(sf.project for sf in corpus_pair.trg_files)
if terms_config["include_glosses"]:
if "en" in pair_src_isos:
self.src_file_paths.update(get_terms_glosses_file_paths(corpus_pair.src_terms_files))
@@ -595,11 +599,11 @@ def _write_scripture_data_sets(
corpus["source"] = [self._noise(pair.src_noise, x) for x in corpus["source"]]

if len(pair.corpus_books) > 0:
-            cur_train = include_books(corpus, pair.corpus_books)
-            if len(pair.corpus_books.intersection(pair.test_books)) > 0:
-                cur_train = exclude_books(cur_train, pair.test_books)
+            cur_train = include_chapters(corpus, pair.corpus_books)
+            if len(pair.test_books) > 0:
+                cur_train = exclude_chapters(cur_train, pair.test_books)
        elif len(pair.test_books) > 0:
-            cur_train = exclude_books(corpus, pair.test_books)
+            cur_train = exclude_chapters(corpus, pair.test_books)
else:
cur_train = corpus

@@ -622,7 +626,7 @@ def _write_scripture_data_sets(
test_indices = set(random.sample(indices, min(split_size, len(indices))))

if len(pair.test_books) > 0:
-            cur_test = include_books(corpus, pair.test_books)
+            cur_test = include_chapters(corpus, pair.test_books)
if test_size > 0:
_, cur_test = split_parallel_corpus(
cur_test,
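
Taken together (illustrative sketch, not part of the commit): a book can now contribute some chapters to training and others to testing, which the old Set[int] logic could not express. The toy corpus stands in for get_scripture_parallel_corpus output; 40 is Matthew.

from types import SimpleNamespace

import pandas as pd

from silnlp.common.corpus import exclude_chapters, include_chapters

corpus = pd.DataFrame({"vref": [SimpleNamespace(book_num=40, chapter_num=c) for c in (1, 2, 3)]})
corpus_books = {40: []}  # corpus_books: "MAT" (whole book), as parsed by get_chapters
test_books = {40: [1]}   # test_books: "MAT1"

cur_train = exclude_chapters(include_chapters(corpus, corpus_books), test_books)  # chapters 2-3
cur_test = include_chapters(corpus, test_books)                                   # chapter 1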
12 changes: 6 additions & 6 deletions silnlp/nmt/test.py
@@ -7,7 +7,7 @@
from typing import IO, Dict, List, Optional, Set, TextIO, Tuple

import sacrebleu
-from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, book_number_to_id, get_books
+from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, book_number_to_id, get_chapters
from sacrebleu.metrics import BLEU, BLEUScore

from ..common.environment import SIL_NLP_ENV
@@ -151,7 +151,7 @@ def process_individual_books(
ref_file_paths: List[Path],
vref_file_path: Path,
select_rand_ref_line: bool,
-    books: Set[int],
+    books: Dict[int, List[int]],
) -> Dict[str, Tuple[List[str], List[List[str]]]]:
# Output data structure
book_dict: Dict[str, Tuple[List[str], List[List[str]]]] = {}
@@ -210,7 +210,7 @@ def load_test_data(
output_file_name: str,
ref_projects: Set[str],
config: Config,
-    books: Set[int],
+    books: Dict[int, List[int]],
by_book: bool,
) -> Tuple[List[str], List[List[str]], Dict[str, Tuple[List[str], List[List[str]]]]]:
sys: List[str] = []
@@ -336,15 +336,15 @@ def test_checkpoint(
checkpoint_type: CheckpointType,
step: int,
scorers: Set[str],
-    books: Set[int],
+    books: Dict[int, List[int]],
) -> List[PairScore]:
config.set_seed()
vref_file_names: List[str] = []
source_file_names: List[str] = []
translation_file_names: List[str] = []
refs_patterns: List[str] = []
translation_detok_file_names: List[str] = []
-    suffix_str = "_".join(map(lambda n: book_number_to_id(n), sorted(books)))
+    suffix_str = "_".join(map(lambda n: book_number_to_id(n), sorted(books.keys())))
if len(suffix_str) > 0:
suffix_str += "-"
suffix_str += "avg" if step == -1 else str(step)
@@ -479,7 +479,7 @@ def test(
LOGGER.info("No test dataset.")
return

-    books_nums = get_books(books)
+    books_nums = get_chapters(";".join(books))

if len(scorers) == 0:
scorers.add("bleu")
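
As a small illustration (not part of the commit) of the suffix change above: the test-output suffix is now derived from the chapter dict's book keys. book_number_to_id comes from machine.scripture; the sample dict mirrors what get_chapters would produce.

from machine.scripture import book_number_to_id

books = {40: [1, 2], 41: []}  # e.g. the assumed result of get_chapters("MAT1-2;MRK")
suffix_str = "_".join(book_number_to_id(n) for n in sorted(books.keys()))
# suffix_str == "MAT_MRK"; test_checkpoint then appends "avg" or the step number.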

