test_text2token.py 3.8 KB
# sherpa-onnx/python/tests/test_text2token.py
#
# Copyright (c)  2023  Xiaomi Corporation
#
# To run this single test, use
#
#  ctest --verbose -R  test_text2token_py

import unittest
from pathlib import Path

import sherpa_onnx

d = "/tmp/sherpa-test-data"
# Please refer to
# https://github.com/pkufool/sherpa-test-data
# to download test data for testing


class TestText2Token(unittest.TestCase):
    def test_bpe(self):
        tokens = f"{d}/text2token/tokens_en.txt"
        bpe_model = f"{d}/text2token/bpe_en.model"

        if not Path(tokens).is_file() or not Path(bpe_model).is_file():
            print(
                f"No test data found, skipping test_bpe().\n"
                f"You can download the test data by: \n"
                f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
            )
            return

        texts = ["HELLO WORLD", "I LOVE YOU"]
        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
        )
        assert encoded_texts == [
            ["▁HE", "LL", "O", "▁WORLD"],
            ["▁I", "▁LOVE", "▁YOU"],
        ], encoded_texts

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids

    def test_cjkchar(self):
        tokens = f"{d}/text2token/tokens_cn.txt"

        if not Path(tokens).is_file():
            print(
                f"No test data found, skipping test_cjkchar().\n"
                f"You can download the test data by: \n"
                f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
            )
            return

        texts = ["世界人民大团结", "中国 VS 美国"]
        encoded_texts = sherpa_onnx.text2token(
            texts, tokens=tokens, tokens_type="cjkchar"
        )
        assert encoded_texts == [
            ["世", "界", "人", "民", "大", "团", "结"],
            ["中", "国", "V", "S", "美", "国"],
        ], encoded_texts
        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar",
            output_ids=True,
        )
        assert encoded_ids == [
            [379, 380, 72, 874, 93, 1251, 489],
            [262, 147, 3423, 2476, 21, 147],
        ], encoded_ids

    def test_cjkchar_bpe(self):
        tokens = f"{d}/text2token/tokens_mix.txt"
        bpe_model = f"{d}/text2token/bpe_mix.model"

        if not Path(tokens).is_file() or not Path(bpe_model).is_file():
            print(
                f"No test data found, skipping test_cjkchar_bpe().\n"
                f"You can download the test data by: \n"
                f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
            )
            return

        texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]
        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
        )
        assert encoded_texts == [
            ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
            ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
        ], encoded_texts
        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        assert encoded_ids == [
            [1368, 1392, 557, 680, 275, 178, 475],
            [685, 736, 275, 178, 179, 921, 736],
        ], encoded_ids


if __name__ == "__main__":
    unittest.main()