diff --git a/.vscode/settings.json b/.vscode/settings.json index 7504701c..119b98d8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,20 +11,13 @@ "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.organizeImports": "explicit" - }, + } }, "files.associations": { - "*.SFM": "usfm", + "*.SFM": "usfm" }, - "black-formatter.path": [ - "poetry", - "run", - "black" - ], - "isort.args": [ - "--profile", - "black" - ], + "black-formatter.path": ["poetry", "run", "black"], + "isort.args": ["--profile", "black"], "cSpell.words": [ "CLEARML", "DYNACONF", @@ -34,5 +27,6 @@ "Usfm", "venv" ], - "python-envs.defaultEnvManager": "ms-python.python:system", -} \ No newline at end of file + "python-envs.defaultEnvManager": "ms-python.python:poetry", + "python-envs.defaultPackageManager": "ms-python.python:poetry" +} diff --git a/machine/translation/__init__.py b/machine/translation/__init__.py index fccb7fc7..2c48e634 100644 --- a/machine/translation/__init__.py +++ b/machine/translation/__init__.py @@ -23,6 +23,7 @@ from .symmetrized_word_alignment_model import SymmetrizedWordAlignmentModel from .symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer from .trainer import Trainer, TrainStats +from .transductive_word_alignment_model import TransductiveWordAlignmentModel from .translation_constants import MAX_SEGMENT_LENGTH from .translation_engine import TranslationEngine from .translation_model import TranslationModel @@ -69,6 +70,7 @@ "SymmetrizedWordAlignmentModelTrainer", "Trainer", "TrainStats", + "TransductiveWordAlignmentModel", "translate_corpus", "TranslationEngine", "TranslationModel", diff --git a/machine/translation/corpus_ops.py b/machine/translation/corpus_ops.py index 37b350e3..f628dc90 100644 --- a/machine/translation/corpus_ops.py +++ b/machine/translation/corpus_ops.py @@ -5,6 +5,7 @@ from ..corpora.parallel_text_row import ParallelTextRow from ..utils.progress_status import ProgressStatus from 
class _TransductiveWordAlignParallelTextCorpus(ParallelTextCorpus):
    """Parallel corpus whose rows carry the alignments a model retained while it was trained.

    Wraps ``corpus`` and, for every yielded row, looks up the alignment that ``model`` stored for
    the sentence pair at the same position instead of aligning the row again.
    """

    def __init__(self, corpus: ParallelTextCorpus, model: TransductiveWordAlignmentModel) -> None:
        self._corpus = corpus
        self._model = model

    # These are exposed as properties because this class itself reads the same names from the
    # wrapped corpus as plain bool values. Declared as ordinary methods, attribute access on this
    # corpus would hand back a bound method, which is always truthy.
    @property
    def is_source_tokenized(self) -> bool:
        """Whether the wrapped corpus reports its source side as tokenized."""
        return self._corpus.is_source_tokenized

    @property
    def is_target_tokenized(self) -> bool:
        """Whether the wrapped corpus reports its target side as tokenized."""
        return self._corpus.is_target_tokenized

    def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[ParallelTextRow, None, None]:
        """Yield the wrapped corpus rows with ``aligned_word_pairs`` filled from training alignments.

        Args:
            text_ids: Optional ids of the texts to yield rows for. ``None`` yields every row.

        Yields:
            Each selected row, mutated in place so that ``aligned_word_pairs`` holds the retained
            alignment (merged into the alignment already present on the row, when there is one).
        """
        # The training alignments are keyed by the order in which the sentence pairs were added
        # during training, so the full corpus must be iterated to keep the index in sync; rows that
        # are not in the requested texts are skipped rather than filtered out of the enumeration.
        # NOTE(review): this assumes the trainer added every corpus row, in this same order.
        # Confirm for trainers that cap or skip rows (e.g. a non-default max_corpus_count).
        text_id_set = None if text_ids is None else set(text_ids)
        with self._corpus.get_rows() as rows:
            for index, row in enumerate(rows):
                if text_id_set is not None and row.text_id not in text_id_set:
                    continue
                alignment = self._model.get_training_alignment(index)
                # A row may already carry an alignment; when it does, the retained alignment is
                # folded into it and the combined matrix is used.
                known_alignment = WordAlignmentMatrix.from_parallel_text_row(row)
                if known_alignment is not None:
                    known_alignment.priority_symmetrize_with(alignment)
                    alignment = known_alignment
                word_pairs = alignment.to_aligned_word_pairs()
                # Only full word alignment models can score the pairs.
                if isinstance(self._model, WordAlignmentModel):
                    self._model.compute_aligned_word_pair_scores(row.source_segment, row.target_segment, word_pairs)
                row.aligned_word_pairs = word_pairs
                yield row
segments: Sequence[Sequence[Sequence[str]]]) -> Sequence[W results.append(WordAlignmentMatrix(matrix.to_numpy())) return results + @property + def emit_training_alignments(self) -> bool: + return self.direct_word_alignment_model.emit_training_alignments + + @emit_training_alignments.setter + def emit_training_alignments(self, value: bool) -> None: + self.direct_word_alignment_model.emit_training_alignments = value + self.inverse_word_alignment_model.emit_training_alignments = value + + @property + def training_alignment_count(self) -> int: + return self._aligner.num_sentence_pairs + + def get_training_alignment(self, n: int) -> WordAlignmentMatrix: + _, matrix = self._aligner.get_training_alignment(n) + return WordAlignmentMatrix(matrix.to_numpy()) + def create_trainer(self, corpus: ParallelTextCorpus) -> Trainer: - direct_trainer = self._direct_word_alignment_model.create_trainer(corpus) - inverse_trainer = self._inverse_word_alignment_model.create_trainer(corpus.invert()) + direct_trainer = self.direct_word_alignment_model.create_trainer(corpus) + inverse_trainer = self.inverse_word_alignment_model.create_trainer(corpus.invert()) return _Trainer(self, direct_trainer, inverse_trainer) @@ -66,7 +84,7 @@ def __enter__(self) -> ThotSymmetrizedWordAlignmentModel: return self def _reset_aligner(self) -> None: - self._aligner = ta.SymmetrizedAligner( + self._aligner = ta.SymmetrizedAlignmentModel( self.direct_word_alignment_model.thot_model, self.inverse_word_alignment_model.thot_model ) self._aligner.heuristic = _convert_heuristic(self._heuristic) diff --git a/machine/translation/thot/thot_word_alignment_model.py b/machine/translation/thot/thot_word_alignment_model.py index 791b35c6..273e2ad7 100644 --- a/machine/translation/thot/thot_word_alignment_model.py +++ b/machine/translation/thot/thot_word_alignment_model.py @@ -10,6 +10,7 @@ from ...corpora.parallel_text_corpus import ParallelTextCorpus from ...utils.typeshed import StrPath from ..ibm1_word_alignment_model 
import Ibm1WordAlignmentModel +from ..transductive_word_alignment_model import TransductiveWordAlignmentModel from ..word_alignment_matrix import WordAlignmentMatrix from ..word_vocabulary import WordVocabulary from .thot_utils import batch, escape_token, escape_tokens, unescape_token @@ -21,7 +22,7 @@ _MAX_BATCH_SIZE = 10240 -class ThotWordAlignmentModel(Ibm1WordAlignmentModel): +class ThotWordAlignmentModel(Ibm1WordAlignmentModel, TransductiveWordAlignmentModel): def __init__(self, prefix_filename: Optional[StrPath] = None, create_new: bool = False) -> None: self._set_model(self._create_model()) if prefix_filename is not None: @@ -33,6 +34,7 @@ def __init__(self, prefix_filename: Optional[StrPath] = None, create_new: bool = else: self._prefix_filename = None self.parameters = ThotWordAlignmentParameters() + self.emit_training_alignments = False @property def source_words(self) -> WordVocabulary: @@ -94,6 +96,14 @@ def align_batch(self, segments: Sequence[Sequence[Sequence[str]]]) -> Sequence[W results.append(WordAlignmentMatrix(matrix.to_numpy())) return results + @property + def training_alignment_count(self) -> int: + return self._model.num_sentence_pairs + + def get_training_alignment(self, n: int) -> WordAlignmentMatrix: + _, matrix = self._model.get_training_alignment(n) + return WordAlignmentMatrix(matrix.to_numpy()) + def get_translation_score( self, source_word: Optional[Union[str, int]], target_word: Optional[Union[str, int]] ) -> float: @@ -199,7 +209,13 @@ class _Trainer(ThotWordAlignmentModelTrainer): def __init__( self, model: ThotWordAlignmentModel, corpus: ParallelTextCorpus, prefix_filename: Optional[StrPath] ) -> None: - super().__init__(model.type, corpus, prefix_filename, model.parameters) + super().__init__( + model.type, + corpus, + prefix_filename, + model.parameters, + emit_training_alignments=model.emit_training_alignments, + ) self._machine_model = model def save(self) -> None: diff --git 
a/machine/translation/thot/thot_word_alignment_model_trainer.py b/machine/translation/thot/thot_word_alignment_model_trainer.py index e0ec21f6..e9de1524 100644 --- a/machine/translation/thot/thot_word_alignment_model_trainer.py +++ b/machine/translation/thot/thot_word_alignment_model_trainer.py @@ -29,6 +29,7 @@ def __init__( source_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER, target_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER, max_corpus_count: int = sys.maxsize, + emit_training_alignments: bool = False, ) -> None: ... @overload @@ -40,6 +41,8 @@ def __init__( parameters: ThotWordAlignmentParameters = ThotWordAlignmentParameters(), source_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER, target_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER, + max_corpus_count: int = sys.maxsize, + emit_training_alignments: bool = False, ) -> None: ... def __init__( @@ -51,6 +54,7 @@ def __init__( source_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER, target_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER, max_corpus_count: int = sys.maxsize, + emit_training_alignments: bool = False, ) -> None: if isinstance(corpus, tuple) and max_corpus_count != sys.maxsize: raise ValueError("max_corpus_count cannot be set when corpus filenames are provided.") @@ -60,6 +64,7 @@ def __init__( self._max_corpus_count = max_corpus_count self.source_tokenizer = source_tokenizer self.target_tokenizer = target_tokenizer + self.emit_training_alignments = emit_training_alignments self._stats = TrainStats() if isinstance(model_type, str): @@ -216,6 +221,12 @@ def report() -> None: if check_canceled is not None: check_canceled() + if self.emit_training_alignments: + # Retain the alignments computed during training so that they can be returned without a + # separate inference pass. Only the final (most refined) model's alignments are needed, + # since that is the model used for inference. 
class TransductiveWordAlignmentModel(ABC):
    """Interface for word alignment models that can hand back alignments kept from training.

    Implementations expose how many training alignments are available and allow each one to be
    fetched by its position in the training corpus.
    """

    @property
    @abstractmethod
    def training_alignment_count(self) -> int:
        """Number of training sentence pairs whose alignment can be retrieved."""

    @abstractmethod
    def get_training_alignment(self, n: int) -> WordAlignmentMatrix:
        """Return the alignment retained for the ``n``-th training sentence pair.

        Args:
            n: Zero-based position of the sentence pair, in the order the pairs appear in the
                training corpus.
        """
"sha256:5f25aa5d4c9ab7ff8d749f325ab8a25273a2ede67e1cd1d6732527a1fa6773a7"}, - {file = "sil_thot-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:99a6cb3c6b65aec40970176f5c65840c033086e95db8264f84b2fea546d50094"}, - {file = "sil_thot-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e8cd801c71fb919724554759ab4051d5a6ef698deca1a3c12168300cc665d361"}, - {file = "sil_thot-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a189068d738402f564751390abc428aed99f287afb6267d364994d5aff39df02"}, - {file = "sil_thot-3.5.0-cp311-cp311-win32.whl", hash = "sha256:cf77c198033ba9ac96af657b42f4da615bff652b222d004c40fd03a8f03220cb"}, - {file = "sil_thot-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:0e840d02c70b17de70b60bfbd6d50c4258a86ca5de38718e2491e98c926c23df"}, - {file = "sil_thot-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:fb9c6916a1c6cd083bd631a53605cc6ac24aa0655d207e592e10a6759a9deaa6"}, - {file = "sil_thot-3.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e572d8dd5f8a15c8051736fb4376c38ef8e8363c2387f42c193f0ed77b00fc04"}, - {file = "sil_thot-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ddd77f14a2ffe3d2f7eb08239fe74d006dc661241d99e5ad4b8e844a4baf294"}, - {file = "sil_thot-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a107755bf22b841db9ffc599c1b9bcf56df2849513c21df7aa8bdf6a694a2c1e"}, - {file = "sil_thot-3.5.0-cp312-cp312-win32.whl", hash = "sha256:5b87815ba98cee0bb55ee95737683511cde009aa26229a0ed69d74b73b01280a"}, - {file = "sil_thot-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:071c4d6c8b7aac30355c3343f420173ee8fb2a9aabbe7a14db26dff68df6f227"}, - {file = "sil_thot-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:db71f2ba820e161576afe633f22dec1e84b75d05dfcd48c8c621876644ef0669"}, - {file = "sil_thot-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18d177227c274fe4359efe531b61e53afd6189f1cce6d127318c54d34ced6c6e"}, - {file = 
"sil_thot-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:03d0bfe640214241dba18373e3f24a7684fbfa3f9a0695ed8c6b60604256663f"}, - {file = "sil_thot-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8ea9a848895608b3287c72872ac30303fec192bea343728831ddaa13326b462a"}, - {file = "sil_thot-3.5.0-cp313-cp313-win32.whl", hash = "sha256:320e29e288faac7bbef3ab30f2dcb2631d90e35edf77229b0573f5851f8fcbb5"}, - {file = "sil_thot-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:af54a686c43829159f60099c0e4181fb28dbd8df809a78df77089aebe458c34c"}, - {file = "sil_thot-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:90ceb0c024ffcf6e3fccb8ea122f62d28efeef0166e949aec0701f053ca5f776"}, - {file = "sil_thot-3.5.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:f0ad4729aea14426b669ce05086fbf7bee79399f4fe7c4fb51c035db85d29288"}, - {file = "sil_thot-3.5.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8573ed74fe766f7b85d75a36569053b40de74b13075972f7c499d5da59eef349"}, - {file = "sil_thot-3.5.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1961d018b4f5f0ce7e478f19cd505a7b99732761e32968d5ac5312d1ff25d047"}, - {file = "sil_thot-3.5.0-cp314-cp314-win32.whl", hash = "sha256:fed0aebe76c51cccfe60ba2656f6df6b5aa3d46276fd54435fca4ec20a3c4c32"}, - {file = "sil_thot-3.5.0-cp314-cp314-win_amd64.whl", hash = "sha256:b8d47b4755cc28b745d4aef92fe74a116b452389290df0453a8aff702c4a4841"}, - {file = "sil_thot-3.5.0-cp314-cp314-win_arm64.whl", hash = "sha256:3afdce0a4324643a64d308c7a6e657794f130563781612eb916336bc88d2d5b1"}, - {file = "sil_thot-3.5.0.tar.gz", hash = "sha256:20d0651e27837a8cfd6cb09da9b21cee1a312745ab52e671e7cfbe2d38f73a80"}, + {file = "sil_thot-3.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896c9d285cad8febdb053110e2cb69098cdbf749f58cb8fc401ecaef4ce3b31c"}, + {file = "sil_thot-3.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e16737d6c39d7596864e8b61df688ba69ea3855dfae2df9db269cd459ad10246"}, + {file 
= "sil_thot-3.5.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b1c701109371a1b0bcc48b86cccd84edadb30ae872496fb559ec93352463f6c"}, + {file = "sil_thot-3.5.1-cp310-cp310-win32.whl", hash = "sha256:d58040685f1cdb6a8500944c49b6f7369382059f62d6a76e818262458076af00"}, + {file = "sil_thot-3.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:b46e971ba975db88f4f186c89f2f1225b4e6555141184eae647613af12a8ec38"}, + {file = "sil_thot-3.5.1-cp310-cp310-win_arm64.whl", hash = "sha256:1acbbb874a284684fc717184af6fa613257e7689cd459e28ceaadc196d264409"}, + {file = "sil_thot-3.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8e81fb6f8344a8346fac772b20b773bdaadaa83be89b8d841b07c9cd7eb99abd"}, + {file = "sil_thot-3.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b7e870101a22e0fee84f323c57f596e8466bbbdf98c062b2ec55854c855a54dd"}, + {file = "sil_thot-3.5.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e72e9c625301f4f7ea87918fd33ca4915873a937f4efe934fb144bdfe05f0281"}, + {file = "sil_thot-3.5.1-cp311-cp311-win32.whl", hash = "sha256:5365d1d5a6bb21d10659ef26dbf7d1d74aae3de12b25e7f1207c8ee1873a95bb"}, + {file = "sil_thot-3.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5318cbbd08c1f567d6d1d248e24aba00eb6cd2110ca92c3019db3fbe251aec44"}, + {file = "sil_thot-3.5.1-cp311-cp311-win_arm64.whl", hash = "sha256:806936112ae26d907c6391047a5547f7aaf458c7c565a520aafa1823878ae31a"}, + {file = "sil_thot-3.5.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:21f1e7e22b23b1292f6e2347f093967426d9dfa523b4bb09ea49382fc40cee28"}, + {file = "sil_thot-3.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5e8b8200c9d8f958eb1b425101252997c072225bbd21973c51656cd853bef9dc"}, + {file = "sil_thot-3.5.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ab04c62ae204cec86c766c7c0d9a74dc0c96f45774c5b210c3921a8f4d000004"}, + {file = "sil_thot-3.5.1-cp312-cp312-win32.whl", hash = 
"sha256:ac9cab93314d54b0328ff9a78b84182563bb92d350d5ebecf595ae35a4c23781"}, + {file = "sil_thot-3.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:d9cf20fe13f662070a801e611f9cacbc6c9e547ff42989080dd2a4ab350cbaae"}, + {file = "sil_thot-3.5.1-cp312-cp312-win_arm64.whl", hash = "sha256:9bf45b58d7335f6693fe074f450f7219f3211916da42a7ae4cc5243932e0ea43"}, + {file = "sil_thot-3.5.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3ec0159e492e6eadedef52f67c6b5f397c0359e42b65309a6607edd6d0e7b69d"}, + {file = "sil_thot-3.5.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:65a699fc9ff6e7917c4b065a99cb3c7f6df5fe41d1e0cf599b5d7e7a33f2b112"}, + {file = "sil_thot-3.5.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:865d65c7354083d50c5e0b29acf0c50fcaa68256e0856a9c6d01a749adc4bc8b"}, + {file = "sil_thot-3.5.1-cp313-cp313-win32.whl", hash = "sha256:311df2a158cf6a0b9febee144d822da964e287b3c3ae996cbb6ec810eca7aba3"}, + {file = "sil_thot-3.5.1-cp313-cp313-win_amd64.whl", hash = "sha256:0fb8f015d5f1486651c2136f64e00fcc6038990a43d3fce9516612a6c7ed4af0"}, + {file = "sil_thot-3.5.1-cp313-cp313-win_arm64.whl", hash = "sha256:2eb9131ec0be2ffb7fb3ddd3f3ac9801c6fddb8cda6fd73ca19e66ed250af836"}, + {file = "sil_thot-3.5.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:34d1663f402382a72f0feb97b4961e35d8b87859c1de00bc890147a05bee1120"}, + {file = "sil_thot-3.5.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:afa926a1c7e261c33008a0c0ee9d0a923b0837c277fc0da07c57d28c8c6cdf9d"}, + {file = "sil_thot-3.5.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24f24a9a19f95b7d5f3954e2bb4eff7ee9798a44fd7cae1c48fdd6601f01af31"}, + {file = "sil_thot-3.5.1-cp314-cp314-win32.whl", hash = "sha256:19a4d070cb3b86135f2837742daef87eb8d9f5101215699605cc4a4133fe1a41"}, + {file = "sil_thot-3.5.1-cp314-cp314-win_amd64.whl", hash = "sha256:3c40d3fc149a7b8d1eb4aed0f8e7d63e533f76e1e5f61284e756234d830468b9"}, + {file = 
"sil_thot-3.5.1-cp314-cp314-win_arm64.whl", hash = "sha256:dccb748942ca9e1a3bddf68f20096cbac8217457389f8b674cdad0fb409f8868"}, + {file = "sil_thot-3.5.1.tar.gz", hash = "sha256:29defec1d82b1017e6241b74d9ae8278f2283fefe6ab2581cf4b68365f659fed"}, ] [package.extras] @@ -5270,4 +5270,4 @@ thot = ["sil-thot"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "3015d427d7e11c94f8e11cdf6970930e2db564428736cd38fb3bd7513c78f947" +content-hash = "5052b629991b8e2d0d668333aa3271ebd5bf1597e5301ead6b7871e74d165528" diff --git a/pyproject.toml b/pyproject.toml index 77dc1d9b..f65ed036 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ charset-normalizer = "^2.1.1" urllib3 = "<2" sentencepiece = "^0.2.0" -sil-thot = "^3.5.0" +sil-thot = "^3.5.1" transformers = "4.46.2" datasets = "^2.4.0" diff --git a/samples/word_alignment.ipynb b/samples/word_alignment.ipynb index 2e185d76..67a07ccb 100644 --- a/samples/word_alignment.ipynb +++ b/samples/word_alignment.ipynb @@ -6,7 +6,7 @@ "source": [ "# Word Alignment Tutorial\n", "\n", - "In this notebook, we will demonstrate how to use machine to train statistical word alignment models and then use them to predict alignments between sentences. Machine uses the [Thot](https://github.com/sillsdev/thot) library to implement word alignment models. The classes can be enabled by installing the `sil-machine` package with the `thot` optional dependency. Machine has implementations of all common statistical models, including the famous IBM models (1-4), HMM, and FastAlign." + "In this notebook, we will demonstrate how to use machine to train statistical word alignment models and then use them to predict alignments between sentences. Machine uses the [Thot](https://github.com/sillsdev/thot) library to implement word alignment models. The classes can be enabled by installing the `sil-machine` package with the `thot` optional dependency. 
Machine has implementations of all common statistical models, including the famous IBM models (1-4), HMM, FastAlign, and Eflomal." ] }, { @@ -498,6 +498,46 @@ " print(\"Target:\", \" \".join(target_segment))\n", " print(\"Alignment:\", alignment)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Transductive alignment models\n", + "\n", + "So far, we have treated training and alignment as two separate steps: first we train a model on a corpus, and then we use the trained model to *infer* alignments for sentences. This is known as _inductive_ alignment. Often, though, the corpus that we want to align is the very same corpus that we train on. In that case, a separate inference pass repeats work that was already done during training.\n", + "\n", + "Some models compute alignments as a natural byproduct of training. This is especially true for sampling-based models such as Eflomal: aligning a sentence pair from a cold start requires running several sampling iterations to \"burn in\" the alignment, so it is much cheaper to keep the alignments that were already produced while training. A model that can hand back the alignments it computed during training is called a _transductive_ model.\n", + "\n", + "In Machine, transductive models implement the `TransductiveWordAlignmentModel` interface. To retain the training alignments, set the `emit_training_alignments` property to `True` before training. Afterwards, the `get_training_alignment` method returns the alignment for the *n*-th sentence pair, in the order that the pairs appear in the training corpus.\n", + "\n", + "This is exactly what the `word_align_corpus` function does under the hood: since it trains a model on the same corpus that it is aligning, it simply returns the alignments that were produced during training instead of making a second pass." 
def _alignment_strings(corpus: ParallelTextCorpus, text_ids: Optional[Iterable[str]] = None) -> list:
    """Render every row's aligned word pairs (scores omitted) as a string, in row order."""
    rendered = []
    for row in corpus.get_rows(text_ids):
        rendered.append(AlignedWordPair.to_string(row.aligned_word_pairs, include_scores=False))
    return rendered


@pytest.mark.parametrize("aligner", ["fast_align", "ibm1"])
def test_word_align_corpus_transductive_matches_inference(aligner: str) -> None:
    """The alignments kept from training must equal those of an explicit inference pass."""
    # For deterministic models, the alignments retained during training match those produced by a
    # separate inference pass, so the transductive output must equal aligning each row directly.
    aligned_corpus = word_align_corpus(create_test_parallel_corpus(), aligner=aligner)
    transductive = _alignment_strings(aligned_corpus)

    model = create_thot_symmetrized_word_alignment_model(aligner)
    model.heuristic = SymmetrizationHeuristic.GROW_DIAG_FINAL_AND  # word_align_corpus's default
    with model.create_trainer(create_test_parallel_corpus()) as trainer:
        trainer.train()
        trainer.save()

    inference = []
    for row in create_test_parallel_corpus().get_rows():
        matrix = model.align(row.source_segment, row.target_segment)
        inference.append(AlignedWordPair.to_string(matrix.to_aligned_word_pairs(), include_scores=False))

    assert transductive == inference


def test_word_align_corpus_default_is_transductive() -> None:
    """With default arguments every row is returned and at least one row is aligned."""
    aligned_corpus = word_align_corpus(create_test_parallel_corpus())
    rows = list(aligned_corpus.get_rows())
    assert len(rows) == 8
    assert any(row.aligned_word_pairs for row in rows)


def _create_two_text_parallel_corpus() -> StandardParallelTextCorpus:
    """Build a small parallel corpus made of two texts with two rows each."""
    source_texts = [
        MemoryText("text1", [TextRow("text1", 1, "el gato".split()), TextRow("text1", 2, "la casa".split())]),
        MemoryText("text2", [TextRow("text2", 1, "el perro corre".split()), TextRow("text2", 2, "la mesa".split())]),
    ]
    target_texts = [
        MemoryText("text1", [TextRow("text1", 1, "the cat".split()), TextRow("text1", 2, "the house".split())]),
        MemoryText("text2", [TextRow("text2", 1, "the dog runs".split()), TextRow("text2", 2, "the table".split())]),
    ]
    return StandardParallelTextCorpus(DictionaryTextCorpus(*source_texts), DictionaryTextCorpus(*target_texts))


def test_word_align_corpus_transductive_text_ids_keep_index_in_sync() -> None:
    """Requesting a single text must return the same alignments as the unfiltered pass."""
    # Filtering by text must not desync the training-alignment index: the rows for a requested text
    # must get exactly the alignments they got in the unfiltered pass, not those of earlier rows.
    corpus = word_align_corpus(_create_two_text_parallel_corpus(), aligner="fast_align")
    all_rows = list(corpus.get_rows())
    # The first two rows belong to "text1"; everything after them is "text2".
    text2_expected = []
    for row in all_rows[2:]:
        text2_expected.append(AlignedWordPair.to_string(row.aligned_word_pairs, include_scores=False))
    text2_actual = _alignment_strings(corpus, ["text2"])
    assert text2_actual == text2_expected


def test_word_align_corpus_transductive_eflomal() -> None:
    """The Eflomal aligner also returns every row, with at least one row aligned."""
    aligned_corpus = word_align_corpus(create_test_parallel_corpus(), aligner="eflomal")
    rows = list(aligned_corpus.get_rows())
    assert len(rows) == 8
    assert any(row.aligned_word_pairs for row in rows)
def test_emit_training_alignments_single_direction() -> None:
    """A single-direction model keeps one alignment per training row when emission is enabled."""
    corpus = create_test_parallel_corpus()
    first_row = next(iter(corpus.get_rows()))

    model = ThotFastAlignWordAlignmentModel()
    model.emit_training_alignments = True
    with model.create_trainer(corpus) as trainer:
        trainer.train()
        trainer.save()

    assert model.training_alignment_count == 8
    # For a deterministic model, the retained training alignment matches the inference alignment,
    # and it survives the trainer being closed.
    retained = model.get_training_alignment(0)
    inferred = model.align(first_row.source_segment, first_row.target_segment)
    assert retained == inferred


def test_emit_training_alignments_symmetrized() -> None:
    """A symmetrized model keeps one alignment per training row when emission is enabled."""
    corpus = create_test_parallel_corpus()
    first_row = next(iter(corpus.get_rows()))

    model = create_thot_symmetrized_word_alignment_model("fast_align")
    model.emit_training_alignments = True
    with model.create_trainer(corpus) as trainer:
        trainer.train()
        trainer.save()

    assert model.training_alignment_count == 8
    # The C++ symmetrized transductive alignment matches the C++ symmetrized inference alignment.
    retained = model.get_training_alignment(0)
    inferred = model.align(first_row.source_segment, first_row.target_segment)
    assert retained == inferred


def test_emit_training_alignments_disabled() -> None:
    """Without emission enabled, asking for a training alignment still yields a matrix."""
    corpus = create_test_parallel_corpus()
    model = ThotFastAlignWordAlignmentModel()
    with model.create_trainer(corpus) as trainer:
        trainer.train()
        trainer.save()

    # When emission is not enabled, retrieval returns a degenerate result rather than raising.
    alignment = model.get_training_alignment(0)
    assert isinstance(alignment, WordAlignmentMatrix)