# test_text2token.py (3.8 KB)
# sherpa-onnx/python/tests/test_text2token.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_text2token_py
import unittest
from pathlib import Path
import sherpa_onnx
# Root directory of the test data used by every test below. Please refer to
# https://github.com/pkufool/sherpa-test-data
# to download the test data for testing, e.g.
#   git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data
d: str = "/tmp/sherpa-test-data"
class TestText2Token(unittest.TestCase):
    """Tests for ``sherpa_onnx.text2token``.

    Each test encodes the same sentences twice -- once to token strings and
    once to token ids (``output_ids=True``) -- for one ``tokens_type``:
    ``"bpe"``, ``"cjkchar"`` or ``"cjkchar+bpe"``. Input files are read from
    the module-level test-data directory ``d``; the expected values are tied
    to the exact token / BPE-model files in that directory.

    Checks use ``self.assertEqual`` rather than bare ``assert`` so they still
    run under ``python -O``, and missing test data is reported through
    ``self.skipTest`` so such runs show up as *skipped* instead of *passed*.
    """

    def _require_files(self, test_name, *paths):
        """Skip the running test unless every path in *paths* is a file.

        Args:
          test_name: Name of the calling test, used in the skip message.
          *paths: File paths (str) the test needs.

        Raises:
          unittest.SkipTest: (via ``self.skipTest``) if any path is missing.
        """
        if all(Path(p).is_file() for p in paths):
            return
        self.skipTest(
            f"No test data found, skipping {test_name}().\n"
            "You can download the test data by: \n"
            f"git clone https://github.com/pkufool/sherpa-test-data.git {d}"
        )

    def test_bpe(self):
        """Encode English text with a BPE model only (``tokens_type="bpe"``)."""
        tokens = f"{d}/text2token/tokens_en.txt"
        bpe_model = f"{d}/text2token/bpe_en.model"
        self._require_files("test_bpe", tokens, bpe_model)

        texts = ["HELLO WORLD", "I LOVE YOU"]

        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
        )
        self.assertEqual(
            encoded_texts,
            [
                ["▁HE", "LL", "O", "▁WORLD"],
                ["▁I", "▁LOVE", "▁YOU"],
            ],
        )

        # Same input, but request integer token ids instead of token strings.
        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        self.assertEqual(encoded_ids, [[22, 58, 24, 425], [19, 370, 47]])

    def test_cjkchar(self):
        """Encode text one character per token (``tokens_type="cjkchar"``)."""
        tokens = f"{d}/text2token/tokens_cn.txt"
        self._require_files("test_cjkchar", tokens)

        texts = ["世界人民大团结", "中国 VS 美国"]

        # Per the expected output, every character (Latin letters included,
        # e.g. "VS" -> "V", "S") becomes its own token and spaces are dropped.
        encoded_texts = sherpa_onnx.text2token(
            texts, tokens=tokens, tokens_type="cjkchar"
        )
        self.assertEqual(
            encoded_texts,
            [
                ["世", "界", "人", "民", "大", "团", "结"],
                ["中", "国", "V", "S", "美", "国"],
            ],
        )

        # Same input, but request integer token ids instead of token strings.
        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar",
            output_ids=True,
        )
        self.assertEqual(
            encoded_ids,
            [
                [379, 380, 72, 874, 93, 1251, 489],
                [262, 147, 3423, 2476, 21, 147],
            ],
        )

    def test_cjkchar_bpe(self):
        """Encode mixed CJK/English text (``tokens_type="cjkchar+bpe"``)."""
        tokens = f"{d}/text2token/tokens_mix.txt"
        bpe_model = f"{d}/text2token/bpe_mix.model"
        self._require_files("test_cjkchar_bpe", tokens, bpe_model)

        texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]

        # Per the expected output, CJK characters are split one per token
        # while English words go through the BPE model ("GOES" -> "▁GO", "ES").
        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
        )
        self.assertEqual(
            encoded_texts,
            [
                ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
                ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
            ],
        )

        # Same input, but request integer token ids instead of token strings.
        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        self.assertEqual(
            encoded_ids,
            [
                [1368, 1392, 557, 680, 275, 178, 475],
                [685, 736, 275, 178, 179, 921, 736],
            ],
        )
if __name__ == "__main__":
    # Allow running this file directly with the Python interpreter, in
    # addition to running it through ctest.
    unittest.main()