utils.py
4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) 2023 Xiaomi Corporation
import re
from pathlib import Path
from typing import List, Optional, Union
def text2token(
texts: List[str],
tokens: str,
tokens_type: str = "cjkchar",
bpe_model: Optional[str] = None,
output_ids: bool = False,
) -> List[List[Union[str, int]]]:
"""
Encode the given texts (a list of string) to a list of a list of tokens.
Args:
texts:
The given contexts list (a list of string).
tokens:
The path of the tokens.txt.
tokens_type:
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final,
bpe_model:
The path of the bpe model. Only required when tokens_type is bpe or
cjkchar+bpe.
output_ids:
True to output token ids otherwise tokens.
Returns:
Return the encoded texts, it is a list of a list of token ids if output_ids
is True, or it is a list of list of tokens.
"""
try:
import sentencepiece as spm
except ImportError:
print("Please run")
print(" pip install sentencepiece")
print("before you continue")
raise
try:
from pypinyin import pinyin
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
except ImportError:
print("Please run")
print(" pip install pypinyin")
print("before you continue")
raise
assert Path(tokens).is_file(), f"File not exists, {tokens}"
tokens_table = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens_table, f"Duplicate token: {toks} "
tokens_table[toks[0]] = int(toks[1])
if "bpe" in tokens_type:
assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}"
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
texts_list: List[List[str]] = []
if tokens_type == "cjkchar":
texts_list = [list("".join(text.split())) for text in texts]
elif tokens_type == "bpe":
texts_list = sp.encode(texts, out_type=str)
elif "pinyin" in tokens_type:
for txt in texts:
py = [x[0] for x in pinyin(txt)]
if "ppinyin" == tokens_type:
res = []
for x in py:
initial = to_initials(x, strict=False)
final = to_finals_tone(x, strict=False)
if initial == "" and final == "":
res.append(x)
else:
if initial != "":
res.append(initial)
if final != "":
res.append(final)
texts_list.append(res)
else:
texts_list.append(py)
else:
assert (
tokens_type == "cjkchar+bpe"
), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r"([\u4e00-\u9fff])")
for text in texts:
# Example:
# txt = "你好 ITS'S OKAY 的"
# chars = ["你", "好", " ITS'S OKAY ", "的"]
chars = pattern.split(text)
mix_chars = [w for w in chars if len(w.strip()) > 0]
text_list = []
for ch_or_w in mix_chars:
# ch_or_w is a single CJK charater(i.e., "你"), do nothing.
if pattern.fullmatch(ch_or_w) is not None:
text_list.append(ch_or_w)
# ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
# encode ch_or_w using bpe_model.
else:
text_list += sp.encode_as_pieces(ch_or_w)
texts_list.append(text_list)
result: List[List[Union[int, str]]] = []
for text in texts_list:
text_list = []
contain_oov = False
for txt in text:
if txt in tokens_table:
text_list.append(tokens_table[txt] if output_ids else txt)
else:
print(
f"Can't find token {txt} in token table, check your "
f"tokens.txt see if {txt} in it. skipping text : {text}."
)
contain_oov = True
break
if contain_oov:
continue
else:
result.append(text_list)
return result