diff --git a/nemo_text_processing/text_normalization/en/taggers/serial.py b/nemo_text_processing/text_normalization/en/taggers/serial.py index f650c8ff3..6e4ec3d5b 100644 --- a/nemo_text_processing/text_normalization/en/taggers/serial.py +++ b/nemo_text_processing/text_normalization/en/taggers/serial.py @@ -28,16 +28,102 @@ from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels +def _leading_zero_graph(cardinal: GraphFst) -> "pynini.FstLike": + return pynini.compose(pynini.accep("0") + pynini.closure(NEMO_DIGIT), cardinal.single_digits_graph).optimize() + + +def _num_graph_alphanumeric(cardinal: GraphFst) -> "pynini.FstLike": + """ + For digit runs inside letter-digit tokens: + - 1-2 digits, or single digits followed by zeros -> cardinal + - 3 digits not ending in 00, or 4+ digits -> single-digit reading + """ + + single_digit_followed_by_zeros = pynini.compose(NEMO_DIGIT, pynini.closure("0", 1)) + three_not_00 = pynini.difference(NEMO_DIGIT**3, NEMO_DIGIT + NEMO_DIGIT + "00") + return ( + pynini.compose(NEMO_DIGIT, cardinal.graph) + | pynini.compose(NEMO_DIGIT**2, cardinal.graph) + | pynini.compose(NEMO_DIGIT + pynini.closure("0", 1), cardinal.graph) + | pynini.compose(three_not_00, cardinal.single_digits_graph) + | pynini.compose(NEMO_DIGIT ** (4, ...), cardinal.single_digits_graph) + | _leading_zero_graph(cardinal) + ).optimize() + + +def _num_graph_pure(cardinal: GraphFst) -> "pynini.FstLike": + """Digit-only serial segments between non-slash delimiters.""" + return ( + pynini.compose(NEMO_DIGIT ** (1, 3), cardinal.graph) + | pynini.compose(NEMO_DIGIT ** (4, ...), cardinal.single_digits_graph) + | _leading_zero_graph(cardinal) + ).optimize() + + +def _num_graph_slash(cardinal: GraphFst) -> "pynini.FstLike": + """Slash-separated digit-only tokens use cardinal (5+ digits stay single-digit).""" + return ( + pynini.compose(NEMO_DIGIT ** (1, 4), cardinal.graph) + | pynini.compose(NEMO_DIGIT ** (5, ...), cardinal.single_digits_graph) + | _leading_zero_graph(cardinal) + ).optimize() + + +def _build_serial_graph( + num_graph: "pynini.FstLike", + delimiter: "pynini.FstLike", + alphas: "pynini.FstLike", + ordinal: GraphFst, +) -> "pynini.FstLike": + letter_num = alphas + delimiter + num_graph + num_letter = pynini.closure(num_graph + delimiter, 1) + alphas + next_alpha_or_num = pynini.closure(delimiter + (alphas | num_graph)) + next_alpha_or_num |= pynini.closure( + delimiter + + num_graph + + plurals._priority_union(pynini.accep(" "), pynutil.insert(" "), NEMO_SIGMA).optimize() + + alphas + ) + + serial_graph = letter_num + next_alpha_or_num + serial_graph |= num_letter + next_alpha_or_num + serial_graph |= num_graph + delimiter + num_graph + delimiter + num_graph + pynini.closure(delimiter + num_graph) + + symbols = [x[0] for x in load_labels(get_abs_path("data/whitelist/symbol.tsv"))] + symbols = pynini.union(*symbols) + serial_graph |= pynini.compose(NEMO_SIGMA + symbols + NEMO_SIGMA, num_graph + delimiter + num_graph) + + serial_graph = pynini.compose( + pynini.difference(NEMO_SIGMA, pynini.project(ordinal.graph, "input")), serial_graph + ).optimize() + + serial_graph = pynutil.add_weight(serial_graph, 0.0001) + serial_graph |= ( + pynini.closure(NEMO_NOT_SPACE, 1) + (pynini.cross("^2", " squared") | pynini.cross("^3", " cubed")).optimize() + ) + + serial_graph = ( + pynini.closure((serial_graph | num_graph | alphas) + delimiter) + + serial_graph + + pynini.closure(delimiter + (serial_graph | num_graph | alphas)) + ) + return serial_graph.optimize() + + class SerialFst(GraphFst): """ - This class is a composite class of two other class instances + Finite state transducer for classifying serial numbers without conventional delimiters. + + Digit normalization within letter-digit tokens follows: + 1. 1-2 digits, or single digits followed by zeros -> cardinal + 2. 3 digits not ending in 00, or 4+ digits -> single-digit reading + 3. Digit-only tokens separated by ``/`` -> cardinal per segment (5+ digits stay single-digit) Args: - time: composed tagger and verbalizer - date: composed tagger and verbalizer - cardinal: tagger + cardinal: cardinal tagger + ordinal: ordinal tagger (used to exclude ordinal readings) deterministic: if True will provide a single transduction option, - for False multiple transduction are generated (used for audio-based normalization) + for False multiple transduction are generated (used for audio-based normalization) lm: whether to use for hybrid LM """ @@ -51,28 +137,28 @@ def __init__(self, cardinal: GraphFst, ordinal: GraphFst, deterministic: bool = c325b -> tokens { cardinal { integer: "c three two five b" } } """ if deterministic: - num_graph = pynini.compose(NEMO_DIGIT ** (6, ...), cardinal.single_digits_graph).optimize() - num_graph |= pynini.compose(NEMO_DIGIT ** (1, 5), cardinal.graph).optimize() - # to handle numbers starting with zero - num_graph |= pynini.compose( - pynini.accep("0") + pynini.closure(NEMO_DIGIT), cardinal.single_digits_graph - ).optimize() + num_graph_pure = _num_graph_pure(cardinal) + num_graph_alnum = _num_graph_alphanumeric(cardinal) + num_graph_slash = _num_graph_slash(cardinal) else: - num_graph = cardinal.final_graph + num_graph_pure = cardinal.final_graph + num_graph_alnum = cardinal.final_graph + num_graph_slash = cardinal.final_graph # TODO: "#" doesn't work from the file symbols_graph = pynini.string_file(get_abs_path("data/whitelist/symbol.tsv")).optimize() | pynini.cross( "#", "hash" ) - num_graph |= symbols_graph + num_graph_pure |= symbols_graph + num_graph_alnum |= symbols_graph if not self.deterministic and not lm: - num_graph |= cardinal.single_digits_graph - num_graph |= pynini.compose(num_graph, NEMO_SIGMA + pynutil.delete("hundred ") + NEMO_SIGMA) - # also allow double digits to be pronounced as integer in serial number - num_graph |= pynutil.add_weight( + num_graph_pure |= cardinal.single_digits_graph + num_graph_pure |= pynini.compose(num_graph_pure, NEMO_SIGMA + pynutil.delete("hundred ") + NEMO_SIGMA) + num_graph_pure |= pynutil.add_weight( NEMO_DIGIT**2 @ cardinal.graph_hundred_component_at_least_one_none_zero_digit, weight=0.0001 ) + num_graph_alnum = num_graph_pure # add space between letter and digit/symbol symbols = [x[0] for x in load_labels(get_abs_path("data/whitelist/symbol.tsv"))] @@ -90,44 +176,21 @@ def __init__(self, cardinal: GraphFst, ordinal: GraphFst, deterministic: bool = delimiter |= pynini.cross("-", " dash ") | pynini.cross("/", " slash ") alphas = pynini.closure(NEMO_ALPHA, 1) - letter_num = alphas + delimiter + num_graph - num_letter = pynini.closure(num_graph + delimiter, 1) + alphas - next_alpha_or_num = pynini.closure(delimiter + (alphas | num_graph)) - next_alpha_or_num |= pynini.closure( - delimiter - + num_graph - + plurals._priority_union(pynini.accep(" "), pynutil.insert(" "), NEMO_SIGMA).optimize() - + alphas - ) - - serial_graph = letter_num + next_alpha_or_num - serial_graph |= num_letter + next_alpha_or_num - # numbers only with 2+ delimiters - serial_graph |= ( - num_graph + delimiter + num_graph + delimiter + num_graph + pynini.closure(delimiter + num_graph) - ) - # 2+ symbols - serial_graph |= pynini.compose(NEMO_SIGMA + symbols + NEMO_SIGMA, num_graph + delimiter + num_graph) - - # exclude ordinal numbers from serial options - serial_graph = pynini.compose( - pynini.difference(NEMO_SIGMA, pynini.project(ordinal.graph, "input")), serial_graph - ).optimize() - serial_graph = pynutil.add_weight(serial_graph, 0.0001) - serial_graph |= ( - pynini.closure(NEMO_NOT_SPACE, 1) - + (pynini.cross("^2", " squared") | pynini.cross("^3", " cubed")).optimize() - ) + serial_graph = _build_serial_graph(num_graph_pure, delimiter, alphas, ordinal) + serial_graph_alnum = _build_serial_graph(num_graph_alnum, delimiter, alphas, ordinal) - # at least one serial graph with alpha numeric value and optional additional serial/num/alpha values - serial_graph = ( - pynini.closure((serial_graph | num_graph | alphas) + delimiter) - + serial_graph - + pynini.closure(delimiter + (serial_graph | num_graph | alphas)) + # Rule 3: tokens that contain only digits and slashes (e.g. 31/31/100, 123/261788/2021). + slash_digit_token = ( + pynini.closure(NEMO_DIGIT, 1) + pynini.accep("/") + pynini.closure(NEMO_DIGIT | pynini.accep("/"), 0) ) + slash_serial = pynini.compose( + slash_digit_token, + pynini.closure(num_graph_slash + pynini.accep("/"), 1) + num_graph_slash, + ).optimize() + serial_graph |= pynutil.add_weight(slash_serial, -0.0001) - serial_graph |= pynini.compose(graph_with_space, serial_graph.optimize()).optimize() + serial_graph |= pynini.compose(graph_with_space, serial_graph_alnum.optimize()).optimize() serial_graph = pynini.compose(pynini.closure(NEMO_NOT_SPACE, 2), serial_graph).optimize() # this is not to verbolize "/" as "slash" in cases like "import/export" diff --git a/tests/nemo_text_processing/en/data_text_normalization/test_cases_ordinal.txt b/tests/nemo_text_processing/en/data_text_normalization/test_cases_ordinal.txt index 2e1b5ec7e..d4f073525 100644 --- a/tests/nemo_text_processing/en/data_text_normalization/test_cases_ordinal.txt +++ b/tests/nemo_text_processing/en/data_text_normalization/test_cases_ordinal.txt @@ -24,4 +24,4 @@ 21th~twenty one th 121st~one hundred twenty first 111th~one hundred eleventh -111st~one hundred eleven st \ No newline at end of file +111st~one one one st \ No newline at end of file diff --git a/tests/nemo_text_processing/en/data_text_normalization/test_cases_serial.txt b/tests/nemo_text_processing/en/data_text_normalization/test_cases_serial.txt index f0a6e0a3f..f142ceb7e 100644 --- a/tests/nemo_text_processing/en/data_text_normalization/test_cases_serial.txt +++ b/tests/nemo_text_processing/en/data_text_normalization/test_cases_serial.txt @@ -29,3 +29,5 @@ a 4-kilogram bag~a four-kilogram bag 100-car~one hundred-car 123/261788/2021~one hundred twenty three/two six one seven eight eight/two thousand twenty one 2*8~two asterisk eight +my pnr is t2000~my pnr is t two thousand +your otp is ab9453~your otp is ab nine four five three \ No newline at end of file