utils.py
# Copyright (c) 2023 Xiaomi Corporation

import re
from pathlib import Path
from typing import List, Optional, Union

import sentencepiece as spm


def text2token(
    texts: List[str],
    tokens: str,
    tokens_type: str = "cjkchar",
    bpe_model: Optional[str] = None,
    output_ids: bool = False,
) -> List[List[Union[str, int]]]:
    """
    Encode the given texts (a list of strings) into a list of token lists.

    Args:
      texts:
        The list of texts to encode.
      tokens:
        The path of the tokens.txt.
      tokens_type:
        Valid values are cjkchar, bpe, and cjkchar+bpe.
      bpe_model:
        The path of the BPE model. Required only when tokens_type is bpe or
        cjkchar+bpe.
      output_ids:
        True to output token ids, otherwise tokens.

    Returns:
      The encoded texts: a list of lists of token ids if output_ids is True,
      otherwise a list of lists of tokens.
    """
    assert Path(tokens).is_file(), f"File does not exist: {tokens}"

    # Each line of tokens.txt maps one token to its integer id: "<token> <id>".
    tokens_table = {}
    with open(tokens, "r", encoding="utf-8") as f:
        for line in f:
            toks = line.strip().split()
            assert len(toks) == 2, len(toks)
            assert toks[0] not in tokens_table, f"Duplicate token: {toks}"
            tokens_table[toks[0]] = int(toks[1])

    if "bpe" in tokens_type:
        assert Path(bpe_model).is_file(), f"File does not exist: {bpe_model}"
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    texts_list: List[List[str]] = []

    if tokens_type == "cjkchar":
        # Drop whitespace and split each text into single characters.
        texts_list = [list("".join(text.split())) for text in texts]
    elif tokens_type == "bpe":
        texts_list = sp.encode(texts, out_type=str)
    else:
        assert (
            tokens_type == "cjkchar+bpe"
        ), f"Supported tokens_type values are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"
        # The CJK (China, Japan, Korea) unicode range is [U+4E00, U+9FFF], ref:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        pattern = re.compile(r"([\u4e00-\u9fff])")
        for text in texts:
            # Example:
            #   text = "你好 IT'S OKAY 的"
            #   mix_chars = ["你", "好", " IT'S OKAY ", "的"]
            chars = pattern.split(text)
            mix_chars = [w for w in chars if len(w.strip()) > 0]
            text_list = []
            for ch_or_w in mix_chars:
                # ch_or_w is a single CJK character (e.g., "你"); keep it as is.
                if pattern.fullmatch(ch_or_w) is not None:
                    text_list.append(ch_or_w)
                # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY ");
                # encode it with the BPE model.
                else:
                    text_list += sp.encode_as_pieces(ch_or_w)
            texts_list.append(text_list)
    # Map tokens to ids, skipping any text that contains an OOV token.
    result: List[List[Union[int, str]]] = []
    for text in texts_list:
        text_list = []
        contain_oov = False
        for txt in text:
            if txt in tokens_table:
                text_list.append(tokens_table[txt] if output_ids else txt)
            else:
                print(f"OOV token: {txt}, skipping text: {text}.")
                contain_oov = True
                break
        if contain_oov:
            continue
        result.append(text_list)
    return result
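

# Minimal usage sketch: encode one mixed Chinese/English string to token ids.
# The file names below are hypothetical placeholders; point them at a real
# tokens.txt and, for the bpe/cjkchar+bpe modes, a real SentencePiece model.
if __name__ == "__main__":
    encoded = text2token(
        texts=["你好 IT'S OKAY 的"],
        tokens="tokens.txt",  # hypothetical path
        tokens_type="cjkchar+bpe",
        bpe_model="bpe.model",  # hypothetical path
        output_ids=True,
    )
    print(encoded)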