Wei Kang
Committed by GitHub

Refactor hotwords,support loading hotwords from file (#296)

正在显示 34 个修改的文件 包含 800 行增加297 行删除
@@ -166,3 +166,8 @@ python3 ./python-api-examples/offline-decode-files.py \ @@ -166,3 +166,8 @@ python3 ./python-api-examples/offline-decode-files.py \
166 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose 166 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
167 167
168 rm -rf $repo 168 rm -rf $repo
  169 +
  170 +# test text2token
  171 +git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
  172 +
  173 +python3 sherpa-onnx/python/tests/test_text2token.py --verbose
@@ -39,7 +39,7 @@ jobs: @@ -39,7 +39,7 @@ jobs:
39 - name: Install Python dependencies 39 - name: Install Python dependencies
40 shell: bash 40 shell: bash
41 run: | 41 run: |
42 - python3 -m pip install --upgrade pip numpy 42 + python3 -m pip install --upgrade pip numpy sentencepiece
43 43
44 - name: Install sherpa-onnx 44 - name: Install sherpa-onnx
45 shell: bash 45 shell: bash
@@ -39,7 +39,7 @@ jobs: @@ -39,7 +39,7 @@ jobs:
39 - name: Install Python dependencies 39 - name: Install Python dependencies
40 shell: bash 40 shell: bash
41 run: | 41 run: |
42 - python3 -m pip install --upgrade pip numpy 42 + python3 -m pip install --upgrade pip numpy sentencepiece
43 43
44 - name: Install sherpa-onnx 44 - name: Install sherpa-onnx
45 shell: bash 45 shell: bash
@@ -326,6 +326,31 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): @@ -326,6 +326,31 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
326 ) 326 )
327 327
328 328
  329 +def add_hotwords_args(parser: argparse.ArgumentParser):
  330 + parser.add_argument(
  331 + "--hotwords-file",
  332 + type=str,
  333 + default="",
  334 + help="""
  335 + The file containing hotwords, one words/phrases per line, and for each
  336 + phrase the bpe/cjkchar are separated by a space. For example:
  337 +
  338 + ▁HE LL O ▁WORLD
  339 + 你 好 世 界
  340 + """,
  341 + )
  342 +
  343 + parser.add_argument(
  344 + "--hotwords-score",
  345 + type=float,
  346 + default=1.5,
  347 + help="""
  348 + The hotword score of each token for biasing word/phrase. Used only if
  349 + --hotwords-file is given.
  350 + """,
  351 + )
  352 +
  353 +
329 def check_args(args): 354 def check_args(args):
330 if not Path(args.tokens).is_file(): 355 if not Path(args.tokens).is_file():
331 raise ValueError(f"{args.tokens} does not exist") 356 raise ValueError(f"{args.tokens} does not exist")
@@ -342,6 +367,10 @@ def check_args(args): @@ -342,6 +367,10 @@ def check_args(args):
342 assert Path(args.decoder).is_file(), args.decoder 367 assert Path(args.decoder).is_file(), args.decoder
343 assert Path(args.joiner).is_file(), args.joiner 368 assert Path(args.joiner).is_file(), args.joiner
344 369
  370 + if args.hotwords_file != "":
  371 + assert args.decoding_method == "modified_beam_search", args.decoding_method
  372 + assert Path(args.hotwords_file).is_file(), args.hotwords_file
  373 +
345 374
346 def get_args(): 375 def get_args():
347 parser = argparse.ArgumentParser( 376 parser = argparse.ArgumentParser(
@@ -351,6 +380,7 @@ def get_args(): @@ -351,6 +380,7 @@ def get_args():
351 add_model_args(parser) 380 add_model_args(parser)
352 add_feature_config_args(parser) 381 add_feature_config_args(parser)
353 add_decoding_args(parser) 382 add_decoding_args(parser)
  383 + add_hotwords_args(parser)
354 384
355 parser.add_argument( 385 parser.add_argument(
356 "--port", 386 "--port",
@@ -792,6 +822,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -792,6 +822,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
792 feature_dim=args.feat_dim, 822 feature_dim=args.feat_dim,
793 decoding_method=args.decoding_method, 823 decoding_method=args.decoding_method,
794 max_active_paths=args.max_active_paths, 824 max_active_paths=args.max_active_paths,
  825 + hotwords_file=args.hotwords_file,
  826 + hotwords_score=args.hotwords_score,
795 ) 827 )
796 elif args.paraformer: 828 elif args.paraformer:
797 assert len(args.nemo_ctc) == 0, args.nemo_ctc 829 assert len(args.nemo_ctc) == 0, args.nemo_ctc
@@ -82,7 +82,6 @@ from pathlib import Path @@ -82,7 +82,6 @@ from pathlib import Path
82 from typing import List, Tuple 82 from typing import List, Tuple
83 83
84 import numpy as np 84 import numpy as np
85 -import sentencepiece as spm  
86 import sherpa_onnx 85 import sherpa_onnx
87 86
88 87
@@ -98,43 +97,25 @@ def get_args(): @@ -98,43 +97,25 @@ def get_args():
98 ) 97 )
99 98
100 parser.add_argument( 99 parser.add_argument(
101 - "--bpe-model", 100 + "--hotwords-file",
102 type=str, 101 type=str,
103 default="", 102 default="",
104 help=""" 103 help="""
105 - Path to bpe.model,  
106 - Used only when --decoding-method=modified_beam_search  
107 - """,  
108 - ) 104 + The file containing hotwords, one words/phrases per line, and for each
  105 + phrase the bpe/cjkchar are separated by a space. For example:
109 106
110 - parser.add_argument(  
111 - "--modeling-unit",  
112 - type=str,  
113 - default="char",  
114 - help="""  
115 - The type of modeling unit.  
116 - Valid values are bpe, bpe+char, char.  
117 - Note: the char here means characters in CJK languages. 107 + ▁HE LL O ▁WORLD
  108 + 你 好 世 界
118 """, 109 """,
119 ) 110 )
120 111
121 parser.add_argument( 112 parser.add_argument(
122 - "--contexts",  
123 - type=str,  
124 - default="",  
125 - help="""  
126 - The context list, it is a string containing some words/phrases separated  
127 - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".  
128 - """,  
129 - )  
130 -  
131 - parser.add_argument(  
132 - "--context-score", 113 + "--hotwords-score",
133 type=float, 114 type=float,
134 default=1.5, 115 default=1.5,
135 help=""" 116 help="""
136 - The context score of each token for biasing word/phrase. Used only if  
137 - --contexts is given. 117 + The hotword score of each token for biasing word/phrase. Used only if
  118 + --hotwords-file is given.
138 """, 119 """,
139 ) 120 )
140 121
@@ -273,25 +254,6 @@ def assert_file_exists(filename: str): @@ -273,25 +254,6 @@ def assert_file_exists(filename: str):
273 "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" 254 "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
274 ) 255 )
275 256
276 -  
277 -def encode_contexts(args, contexts: List[str]) -> List[List[int]]:  
278 - sp = None  
279 - if "bpe" in args.modeling_unit:  
280 - assert_file_exists(args.bpe_model)  
281 - sp = spm.SentencePieceProcessor()  
282 - sp.load(args.bpe_model)  
283 - tokens = {}  
284 - with open(args.tokens, "r", encoding="utf-8") as f:  
285 - for line in f:  
286 - toks = line.strip().split()  
287 - assert len(toks) == 2, len(toks)  
288 - assert toks[0] not in tokens, f"Duplicate token: {toks} "  
289 - tokens[toks[0]] = int(toks[1])  
290 - return sherpa_onnx.encode_contexts(  
291 - modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens  
292 - )  
293 -  
294 -  
295 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: 257 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
296 """ 258 """
297 Args: 259 Args:
@@ -322,7 +284,6 @@ def main(): @@ -322,7 +284,6 @@ def main():
322 assert_file_exists(args.tokens) 284 assert_file_exists(args.tokens)
323 assert args.num_threads > 0, args.num_threads 285 assert args.num_threads > 0, args.num_threads
324 286
325 - contexts_list = []  
326 if args.encoder: 287 if args.encoder:
327 assert len(args.paraformer) == 0, args.paraformer 288 assert len(args.paraformer) == 0, args.paraformer
328 assert len(args.nemo_ctc) == 0, args.nemo_ctc 289 assert len(args.nemo_ctc) == 0, args.nemo_ctc
@@ -330,11 +291,6 @@ def main(): @@ -330,11 +291,6 @@ def main():
330 assert len(args.whisper_decoder) == 0, args.whisper_decoder 291 assert len(args.whisper_decoder) == 0, args.whisper_decoder
331 assert len(args.tdnn_model) == 0, args.tdnn_model 292 assert len(args.tdnn_model) == 0, args.tdnn_model
332 293
333 - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]  
334 - if contexts:  
335 - print(f"Contexts list: {contexts}")  
336 - contexts_list = encode_contexts(args, contexts)  
337 -  
338 assert_file_exists(args.encoder) 294 assert_file_exists(args.encoder)
339 assert_file_exists(args.decoder) 295 assert_file_exists(args.decoder)
340 assert_file_exists(args.joiner) 296 assert_file_exists(args.joiner)
@@ -348,7 +304,8 @@ def main(): @@ -348,7 +304,8 @@ def main():
348 sample_rate=args.sample_rate, 304 sample_rate=args.sample_rate,
349 feature_dim=args.feature_dim, 305 feature_dim=args.feature_dim,
350 decoding_method=args.decoding_method, 306 decoding_method=args.decoding_method,
351 - context_score=args.context_score, 307 + hotwords_file=args.hotwords_file,
  308 + hotwords_score=args.hotwords_score,
352 debug=args.debug, 309 debug=args.debug,
353 ) 310 )
354 elif args.paraformer: 311 elif args.paraformer:
@@ -425,12 +382,7 @@ def main(): @@ -425,12 +382,7 @@ def main():
425 samples, sample_rate = read_wave(wave_filename) 382 samples, sample_rate = read_wave(wave_filename)
426 duration = len(samples) / sample_rate 383 duration = len(samples) / sample_rate
427 total_duration += duration 384 total_duration += duration
428 - if contexts_list:  
429 - assert len(args.paraformer) == 0, args.paraformer  
430 - assert len(args.nemo_ctc) == 0, args.nemo_ctc  
431 - s = recognizer.create_stream(contexts_list=contexts_list)  
432 - else:  
433 - s = recognizer.create_stream() 385 + s = recognizer.create_stream()
434 s.accept_waveform(sample_rate, samples) 386 s.accept_waveform(sample_rate, samples)
435 387
436 streams.append(s) 388 streams.append(s)
@@ -48,7 +48,6 @@ from pathlib import Path @@ -48,7 +48,6 @@ from pathlib import Path
48 from typing import List, Tuple 48 from typing import List, Tuple
49 49
50 import numpy as np 50 import numpy as np
51 -import sentencepiece as spm  
52 import sherpa_onnx 51 import sherpa_onnx
53 52
54 53
@@ -124,46 +123,25 @@ def get_args(): @@ -124,46 +123,25 @@ def get_args():
124 ) 123 )
125 124
126 parser.add_argument( 125 parser.add_argument(
127 - "--bpe-model", 126 + "--hotwords-file",
128 type=str, 127 type=str,
129 default="", 128 default="",
130 help=""" 129 help="""
131 - Path to bpe.model, it will be used to tokenize contexts biasing phrases.  
132 - Used only when --decoding-method=modified_beam_search  
133 - """,  
134 - )  
135 -  
136 - parser.add_argument(  
137 - "--modeling-unit",  
138 - type=str,  
139 - default="char",  
140 - help="""  
141 - The type of modeling unit, it will be used to tokenize contexts biasing phrases.  
142 - Valid values are bpe, bpe+char, char.  
143 - Note: the char here means characters in CJK languages.  
144 - Used only when --decoding-method=modified_beam_search  
145 - """,  
146 - ) 130 + The file containing hotwords, one words/phrases per line, and for each
  131 + phrase the bpe/cjkchar are separated by a space. For example:
147 132
148 - parser.add_argument(  
149 - "--contexts",  
150 - type=str,  
151 - default="",  
152 - help="""  
153 - The context list, it is a string containing some words/phrases separated  
154 - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".  
155 - Used only when --decoding-method=modified_beam_search 133 + ▁HE LL O ▁WORLD
  134 + 你 好 世 界
156 """, 135 """,
157 ) 136 )
158 137
159 parser.add_argument( 138 parser.add_argument(
160 - "--context-score", 139 + "--hotwords-score",
161 type=float, 140 type=float,
162 default=1.5, 141 default=1.5,
163 help=""" 142 help="""
164 - The context score of each token for biasing word/phrase. Used only if  
165 - --contexts is given.  
166 - Used only when --decoding-method=modified_beam_search 143 + The hotword score of each token for biasing word/phrase. Used only if
  144 + --hotwords-file is given.
167 """, 145 """,
168 ) 146 )
169 147
@@ -214,27 +192,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: @@ -214,27 +192,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
214 return samples_float32, f.getframerate() 192 return samples_float32, f.getframerate()
215 193
216 194
217 -def encode_contexts(args, contexts: List[str]) -> List[List[int]]:  
218 - sp = None  
219 - if "bpe" in args.modeling_unit:  
220 - assert_file_exists(args.bpe_model)  
221 - sp = spm.SentencePieceProcessor()  
222 - sp.load(args.bpe_model)  
223 - tokens = {}  
224 - with open(args.tokens, "r", encoding="utf-8") as f:  
225 - for line in f:  
226 - toks = line.strip().split()  
227 - assert len(toks) == 2, len(toks)  
228 - assert toks[0] not in tokens, f"Duplicate token: {toks} "  
229 - tokens[toks[0]] = int(toks[1])  
230 - return sherpa_onnx.encode_contexts(  
231 - modeling_unit=args.modeling_unit,  
232 - contexts=contexts,  
233 - sp=sp,  
234 - tokens_table=tokens,  
235 - )  
236 -  
237 -  
238 def main(): 195 def main():
239 args = get_args() 196 args = get_args()
240 assert_file_exists(args.tokens) 197 assert_file_exists(args.tokens)
@@ -258,7 +215,8 @@ def main(): @@ -258,7 +215,8 @@ def main():
258 feature_dim=80, 215 feature_dim=80,
259 decoding_method=args.decoding_method, 216 decoding_method=args.decoding_method,
260 max_active_paths=args.max_active_paths, 217 max_active_paths=args.max_active_paths,
261 - context_score=args.context_score, 218 + hotwords_file=args.hotwords_file,
  219 + hotwords_score=args.hotwords_score,
262 ) 220 )
263 elif args.paraformer_encoder: 221 elif args.paraformer_encoder:
264 recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( 222 recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
@@ -277,12 +235,6 @@ def main(): @@ -277,12 +235,6 @@ def main():
277 print("Started!") 235 print("Started!")
278 start_time = time.time() 236 start_time = time.time()
279 237
280 - contexts_list = []  
281 - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]  
282 - if contexts:  
283 - print(f"Contexts list: {contexts}")  
284 - contexts_list = encode_contexts(args, contexts)  
285 -  
286 streams = [] 238 streams = []
287 total_duration = 0 239 total_duration = 0
288 for wave_filename in args.sound_files: 240 for wave_filename in args.sound_files:
@@ -291,10 +243,7 @@ def main(): @@ -291,10 +243,7 @@ def main():
291 duration = len(samples) / sample_rate 243 duration = len(samples) / sample_rate
292 total_duration += duration 244 total_duration += duration
293 245
294 - if contexts_list:  
295 - s = recognizer.create_stream(contexts_list=contexts_list)  
296 - else:  
297 - s = recognizer.create_stream() 246 + s = recognizer.create_stream()
298 247
299 s.accept_waveform(sample_rate, samples) 248 s.accept_waveform(sample_rate, samples)
300 249
@@ -79,6 +79,30 @@ def get_args(): @@ -79,6 +79,30 @@ def get_args():
79 help="Valid values: cpu, cuda, coreml", 79 help="Valid values: cpu, cuda, coreml",
80 ) 80 )
81 81
  82 + parser.add_argument(
  83 + "--hotwords-file",
  84 + type=str,
  85 + default="",
  86 + help="""
  87 + The file containing hotwords, one words/phrases per line, and for each
  88 + phrase the bpe/cjkchar are separated by a space. For example:
  89 +
  90 + ▁HE LL O ▁WORLD
  91 + 你 好 世 界
  92 + """,
  93 + )
  94 +
  95 + parser.add_argument(
  96 + "--hotwords-score",
  97 + type=float,
  98 + default=1.5,
  99 + help="""
  100 + The hotword score of each token for biasing word/phrase. Used only if
  101 + --hotwords-file is given.
  102 + """,
  103 + )
  104 +
  105 +
82 return parser.parse_args() 106 return parser.parse_args()
83 107
84 108
@@ -104,6 +128,8 @@ def create_recognizer(args): @@ -104,6 +128,8 @@ def create_recognizer(args):
104 rule3_min_utterance_length=300, # it essentially disables this rule 128 rule3_min_utterance_length=300, # it essentially disables this rule
105 decoding_method=args.decoding_method, 129 decoding_method=args.decoding_method,
106 provider=args.provider, 130 provider=args.provider,
  131 + hotwords_file=args.hotwords_file,
  132 + hotwords_score=args.hotwords_score,
107 ) 133 )
108 return recognizer 134 return recognizer
109 135
@@ -11,7 +11,6 @@ import sys @@ -11,7 +11,6 @@ import sys
11 from pathlib import Path 11 from pathlib import Path
12 12
13 from typing import List 13 from typing import List
14 -import sentencepiece as spm  
15 14
16 try: 15 try:
17 import sounddevice as sd 16 import sounddevice as sd
@@ -90,49 +89,29 @@ def get_args(): @@ -90,49 +89,29 @@ def get_args():
90 ) 89 )
91 90
92 parser.add_argument( 91 parser.add_argument(
93 - "--bpe-model", 92 + "--hotwords-file",
94 type=str, 93 type=str,
95 default="", 94 default="",
96 help=""" 95 help="""
97 - Path to bpe.model, it will be used to tokenize contexts biasing phrases.  
98 - Used only when --decoding-method=modified_beam_search  
99 - """,  
100 - ) 96 + The file containing hotwords, one words/phrases per line, and for each
  97 + phrase the bpe/cjkchar are separated by a space. For example:
101 98
102 - parser.add_argument(  
103 - "--modeling-unit",  
104 - type=str,  
105 - default="char",  
106 - help="""  
107 - The type of modeling unit, it will be used to tokenize contexts biasing phrases.  
108 - Valid values are bpe, bpe+char, char.  
109 - Note: the char here means characters in CJK languages.  
110 - Used only when --decoding-method=modified_beam_search 99 + ▁HE LL O ▁WORLD
  100 + 你 好 世 界
111 """, 101 """,
112 ) 102 )
113 103
114 parser.add_argument( 104 parser.add_argument(
115 - "--contexts",  
116 - type=str,  
117 - default="",  
118 - help="""  
119 - The context list, it is a string containing some words/phrases separated  
120 - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".  
121 - Used only when --decoding-method=modified_beam_search  
122 - """,  
123 - )  
124 -  
125 - parser.add_argument(  
126 - "--context-score", 105 + "--hotwords-score",
127 type=float, 106 type=float,
128 default=1.5, 107 default=1.5,
129 help=""" 108 help="""
130 - The context score of each token for biasing word/phrase. Used only if  
131 - --contexts is given.  
132 - Used only when --decoding-method=modified_beam_search 109 + The hotword score of each token for biasing word/phrase. Used only if
  110 + --hotwords-file is given.
133 """, 111 """,
134 ) 112 )
135 113
  114 +
136 return parser.parse_args() 115 return parser.parse_args()
137 116
138 117
@@ -155,32 +134,12 @@ def create_recognizer(args): @@ -155,32 +134,12 @@ def create_recognizer(args):
155 decoding_method=args.decoding_method, 134 decoding_method=args.decoding_method,
156 max_active_paths=args.max_active_paths, 135 max_active_paths=args.max_active_paths,
157 provider=args.provider, 136 provider=args.provider,
158 - context_score=args.context_score, 137 + hotwords_file=args.hotwords_file,
  138 + hotwords_score=args.hotwords_score,
159 ) 139 )
160 return recognizer 140 return recognizer
161 141
162 142
163 -def encode_contexts(args, contexts: List[str]) -> List[List[int]]:  
164 - sp = None  
165 - if "bpe" in args.modeling_unit:  
166 - assert_file_exists(args.bpe_model)  
167 - sp = spm.SentencePieceProcessor()  
168 - sp.load(args.bpe_model)  
169 - tokens = {}  
170 - with open(args.tokens, "r", encoding="utf-8") as f:  
171 - for line in f:  
172 - toks = line.strip().split()  
173 - assert len(toks) == 2, len(toks)  
174 - assert toks[0] not in tokens, f"Duplicate token: {toks} "  
175 - tokens[toks[0]] = int(toks[1])  
176 - return sherpa_onnx.encode_contexts(  
177 - modeling_unit=args.modeling_unit,  
178 - contexts=contexts,  
179 - sp=sp,  
180 - tokens_table=tokens,  
181 - )  
182 -  
183 -  
184 def main(): 143 def main():
185 args = get_args() 144 args = get_args()
186 145
@@ -193,12 +152,6 @@ def main(): @@ -193,12 +152,6 @@ def main():
193 default_input_device_idx = sd.default.device[0] 152 default_input_device_idx = sd.default.device[0]
194 print(f'Use default device: {devices[default_input_device_idx]["name"]}') 153 print(f'Use default device: {devices[default_input_device_idx]["name"]}')
195 154
196 - contexts_list = []  
197 - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]  
198 - if contexts:  
199 - print(f"Contexts list: {contexts}")  
200 - contexts_list = encode_contexts(args, contexts)  
201 -  
202 recognizer = create_recognizer(args) 155 recognizer = create_recognizer(args)
203 print("Started! Please speak") 156 print("Started! Please speak")
204 157
@@ -207,10 +160,7 @@ def main(): @@ -207,10 +160,7 @@ def main():
207 sample_rate = 48000 160 sample_rate = 48000
208 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 161 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
209 last_result = "" 162 last_result = ""
210 - if contexts_list:  
211 - stream = recognizer.create_stream(contexts_list=contexts_list)  
212 - else:  
213 - stream = recognizer.create_stream() 163 + stream = recognizer.create_stream()
214 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: 164 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
215 while True: 165 while True:
216 samples, _ = s.read(samples_per_read) # a blocking read 166 samples, _ = s.read(samples_per_read) # a blocking read
@@ -87,6 +87,30 @@ def get_args(): @@ -87,6 +87,30 @@ def get_args():
87 """, 87 """,
88 ) 88 )
89 89
  90 + parser.add_argument(
  91 + "--hotwords-file",
  92 + type=str,
  93 + default="",
  94 + help="""
  95 + The file containing hotwords, one words/phrases per line, and for each
  96 + phrase the bpe/cjkchar are separated by a space. For example:
  97 +
  98 + ▁HE LL O ▁WORLD
  99 + 你 好 世 界
  100 + """,
  101 + )
  102 +
  103 + parser.add_argument(
  104 + "--hotwords-score",
  105 + type=float,
  106 + default=1.5,
  107 + help="""
  108 + The hotword score of each token for biasing word/phrase. Used only if
  109 + --hotwords-file is given.
  110 + """,
  111 + )
  112 +
  113 +
90 return parser.parse_args() 114 return parser.parse_args()
91 115
92 116
@@ -107,6 +131,8 @@ def create_recognizer(args): @@ -107,6 +131,8 @@ def create_recognizer(args):
107 rule1_min_trailing_silence=2.4, 131 rule1_min_trailing_silence=2.4,
108 rule2_min_trailing_silence=1.2, 132 rule2_min_trailing_silence=1.2,
109 rule3_min_utterance_length=300, # it essentially disables this rule 133 rule3_min_utterance_length=300, # it essentially disables this rule
  134 + hotwords_file=args.hotwords_file,
  135 + hotwords_score=args.hotwords_score,
110 ) 136 )
111 return recognizer 137 return recognizer
112 138
@@ -187,6 +187,32 @@ def add_decoding_args(parser: argparse.ArgumentParser): @@ -187,6 +187,32 @@ def add_decoding_args(parser: argparse.ArgumentParser):
187 add_modified_beam_search_args(parser) 187 add_modified_beam_search_args(parser)
188 188
189 189
  190 +def add_hotwords_args(parser: argparse.ArgumentParser):
  191 + parser.add_argument(
  192 + "--hotwords-file",
  193 + type=str,
  194 + default="",
  195 + help="""
  196 + The file containing hotwords, one words/phrases per line, and for each
  197 + phrase the bpe/cjkchar are separated by a space. For example:
  198 +
  199 + ▁HE LL O ▁WORLD
  200 + 你 好 世 界
  201 + """,
  202 + )
  203 +
  204 + parser.add_argument(
  205 + "--hotwords-score",
  206 + type=float,
  207 + default=1.5,
  208 + help="""
  209 + The hotword score of each token for biasing word/phrase. Used only if
  210 + --hotwords-file is given.
  211 + """,
  212 + )
  213 +
  214 +
  215 +
190 def add_modified_beam_search_args(parser: argparse.ArgumentParser): 216 def add_modified_beam_search_args(parser: argparse.ArgumentParser):
191 parser.add_argument( 217 parser.add_argument(
192 "--num-active-paths", 218 "--num-active-paths",
@@ -239,6 +265,7 @@ def get_args(): @@ -239,6 +265,7 @@ def get_args():
239 add_model_args(parser) 265 add_model_args(parser)
240 add_decoding_args(parser) 266 add_decoding_args(parser)
241 add_endpointing_args(parser) 267 add_endpointing_args(parser)
  268 + add_hotwords_args(parser)
242 269
243 parser.add_argument( 270 parser.add_argument(
244 "--port", 271 "--port",
@@ -343,6 +370,8 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: @@ -343,6 +370,8 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
343 feature_dim=args.feat_dim, 370 feature_dim=args.feat_dim,
344 decoding_method=args.decoding_method, 371 decoding_method=args.decoding_method,
345 max_active_paths=args.num_active_paths, 372 max_active_paths=args.num_active_paths,
  373 + hotwords_score=args.hotwords_score,
  374 + hotwords_file=args.hotwords_file,
346 enable_endpoint_detection=args.use_endpoint != 0, 375 enable_endpoint_detection=args.use_endpoint != 0,
347 rule1_min_trailing_silence=args.rule1_min_trailing_silence, 376 rule1_min_trailing_silence=args.rule1_min_trailing_silence,
348 rule2_min_trailing_silence=args.rule2_min_trailing_silence, 377 rule2_min_trailing_silence=args.rule2_min_trailing_silence,
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script encodes the texts (given line by line through `text`) to tokens and
  5 +write the results to the file given by ``output``.
  6 +
  7 +Usage:
  8 +If the tokens_type is bpe:
  9 +
  10 +python3 ./text2token.py \
  11 + --text texts.txt \
  12 + --tokens tokens.txt \
  13 + --tokens-type bpe \
  14 + --bpe-model bpe.model \
  15 + --output hotwords.txt
  16 +
  17 +If the tokens_type is cjkchar:
  18 +
  19 +python3 ./text2token.py \
  20 + --text texts.txt \
  21 + --tokens tokens.txt \
  22 + --tokens-type cjkchar \
  23 + --output hotwords.txt
  24 +
  25 +If the tokens_type is cjkchar+bpe:
  26 +
  27 +python3 ./text2token.py \
  28 + --text texts.txt \
  29 + --tokens tokens.txt \
  30 + --tokens-type cjkchar+bpe \
  31 + --bpe-model bpe.model \
  32 + --output hotwords.txt
  33 +
  34 +"""
  35 +import argparse
  36 +
  37 +from sherpa_onnx import text2token
  38 +
  39 +def get_args():
  40 + parser = argparse.ArgumentParser()
  41 + parser.add_argument(
  42 + "--text",
  43 + type=str,
  44 + required=True,
  45 + help="Path to the input texts",
  46 + )
  47 +
  48 + parser.add_argument(
  49 + "--tokens",
  50 + type=str,
  51 + required=True,
  52 + help="The path to tokens.txt.",
  53 + )
  54 +
  55 + parser.add_argument(
  56 + "--tokens-type",
  57 + type=str,
  58 + required=True,
  59 + help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
  60 + )
  61 +
  62 + parser.add_argument(
  63 + "--bpe-model",
  64 + type=str,
  65 + help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.",
  66 + )
  67 +
  68 + parser.add_argument(
  69 + "--output",
  70 + type=str,
  71 + required=True,
  72 + help="Path where the encoded tokens will be written to.",
  73 + )
  74 +
  75 + return parser.parse_args()
  76 +
  77 +
  78 +def main():
  79 + args = get_args()
  80 +
  81 + texts = []
  82 + with open(args.text, "r", encoding="utf8") as f:
  83 + for line in f:
  84 + texts.append(line.strip())
  85 + encoded_texts = text2token(
  86 + texts,
  87 + tokens=args.tokens,
  88 + tokens_type=args.tokens_type,
  89 + bpe_model=args.bpe_model,
  90 + )
  91 + with open(args.output, "w", encoding="utf8") as f:
  92 + for txt in encoded_texts:
  93 + f.write(" ".join(txt) + "\n")
  94 +
  95 +
  96 +if __name__ == "__main__":
  97 + main()
@@ -39,6 +39,7 @@ install_requires = [ @@ -39,6 +39,7 @@ install_requires = [
39 "numpy", 39 "numpy",
40 "sentencepiece==0.1.96; python_version < '3.11'", 40 "sentencepiece==0.1.96; python_version < '3.11'",
41 "sentencepiece; python_version >= '3.11'", 41 "sentencepiece; python_version >= '3.11'",
  42 + "click>=7.1.1",
42 ] 43 ]
43 44
44 45
@@ -93,6 +94,11 @@ setuptools.setup( @@ -93,6 +94,11 @@ setuptools.setup(
93 "Programming Language :: Python", 94 "Programming Language :: Python",
94 "Topic :: Scientific/Engineering :: Artificial Intelligence", 95 "Topic :: Scientific/Engineering :: Artificial Intelligence",
95 ], 96 ],
  97 + entry_points={
  98 + 'console_scripts': [
  99 + 'sherpa-onnx-cli=sherpa_onnx.cli:cli',
  100 + ],
  101 + },
96 license="Apache licensed, as found in the LICENSE file", 102 license="Apache licensed, as found in the LICENSE file",
97 ) 103 )
98 104
@@ -72,6 +72,7 @@ set(sources @@ -72,6 +72,7 @@ set(sources
72 text-utils.cc 72 text-utils.cc
73 transpose.cc 73 transpose.cc
74 unbind.cc 74 unbind.cc
  75 + utils.cc
75 wave-reader.cc 76 wave-reader.cc
76 ) 77 )
77 78
@@ -4,11 +4,14 @@ @@ -4,11 +4,14 @@
4 4
5 #include "sherpa-onnx/csrc/context-graph.h" 5 #include "sherpa-onnx/csrc/context-graph.h"
6 6
  7 +#include <chrono> // NOLINT
7 #include <map> 8 #include <map>
  9 +#include <random>
8 #include <string> 10 #include <string>
9 #include <vector> 11 #include <vector>
10 12
11 #include "gtest/gtest.h" 13 #include "gtest/gtest.h"
  14 +#include "sherpa-onnx/csrc/macros.h"
12 15
13 namespace sherpa_onnx { 16 namespace sherpa_onnx {
14 17
@@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) { @@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) {
41 } 44 }
42 } 45 }
43 46
  47 +TEST(ContextGraph, Benchmark) {
  48 + std::random_device rd;
  49 + std::mt19937 mt(rd());
  50 + std::uniform_int_distribution<int32_t> char_dist(0, 25);
  51 + std::uniform_int_distribution<int32_t> len_dist(3, 8);
  52 + for (int32_t num = 10; num <= 10000; num *= 10) {
  53 + std::vector<std::vector<int32_t>> contexts;
  54 + for (int32_t i = 0; i < num; ++i) {
  55 + std::vector<int32_t> tmp;
  56 + int32_t word_len = len_dist(mt);
  57 + for (int32_t j = 0; j < word_len; ++j) {
  58 + tmp.push_back(char_dist(mt));
  59 + }
  60 + contexts.push_back(std::move(tmp));
  61 + }
  62 + auto start = std::chrono::high_resolution_clock::now();
  63 + auto context_graph = ContextGraph(contexts, 1);
  64 + auto stop = std::chrono::high_resolution_clock::now();
  65 + auto duration =
  66 + std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
  67 + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
  68 + duration.count());
  69 + }
  70 +}
  71 +
44 } // namespace sherpa_onnx 72 } // namespace sherpa_onnx
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
7 7
8 #include <memory> 8 #include <memory>
  9 +#include <string>
9 #include <vector> 10 #include <vector>
10 11
11 #if __ANDROID_API__ >= 9 12 #if __ANDROID_API__ >= 9
@@ -32,7 +33,7 @@ class OfflineRecognizerImpl { @@ -32,7 +33,7 @@ class OfflineRecognizerImpl {
32 virtual ~OfflineRecognizerImpl() = default; 33 virtual ~OfflineRecognizerImpl() = default;
33 34
34 virtual std::unique_ptr<OfflineStream> CreateStream( 35 virtual std::unique_ptr<OfflineStream> CreateStream(
35 - const std::vector<std::vector<int32_t>> &context_list) const { 36 + const std::string &hotwords) const {
36 SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); 37 SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
37 exit(-1); 38 exit(-1);
38 } 39 }
@@ -5,7 +5,9 @@ @@ -5,7 +5,9 @@
5 #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ 5 #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
7 7
  8 +#include <fstream>
8 #include <memory> 9 #include <memory>
  10 +#include <regex> // NOLINT
9 #include <string> 11 #include <string>
10 #include <utility> 12 #include <utility>
11 #include <vector> 13 #include <vector>
@@ -16,6 +18,7 @@ @@ -16,6 +18,7 @@
16 #endif 18 #endif
17 19
18 #include "sherpa-onnx/csrc/context-graph.h" 20 #include "sherpa-onnx/csrc/context-graph.h"
  21 +#include "sherpa-onnx/csrc/log.h"
19 #include "sherpa-onnx/csrc/macros.h" 22 #include "sherpa-onnx/csrc/macros.h"
20 #include "sherpa-onnx/csrc/offline-recognizer-impl.h" 23 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
21 #include "sherpa-onnx/csrc/offline-recognizer.h" 24 #include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -25,6 +28,7 @@ @@ -25,6 +28,7 @@
25 #include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" 28 #include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
26 #include "sherpa-onnx/csrc/pad-sequence.h" 29 #include "sherpa-onnx/csrc/pad-sequence.h"
27 #include "sherpa-onnx/csrc/symbol-table.h" 30 #include "sherpa-onnx/csrc/symbol-table.h"
  31 +#include "sherpa-onnx/csrc/utils.h"
28 32
29 namespace sherpa_onnx { 33 namespace sherpa_onnx {
30 34
@@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
60 : config_(config), 64 : config_(config),
61 symbol_table_(config_.model_config.tokens), 65 symbol_table_(config_.model_config.tokens),
62 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { 66 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
  67 + if (!config_.hotwords_file.empty()) {
  68 + InitHotwords();
  69 + }
63 if (config_.decoding_method == "greedy_search") { 70 if (config_.decoding_method == "greedy_search") {
64 decoder_ = 71 decoder_ =
65 std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); 72 std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
@@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
105 #endif 112 #endif
106 113
107 std::unique_ptr<OfflineStream> CreateStream( 114 std::unique_ptr<OfflineStream> CreateStream(
108 - const std::vector<std::vector<int32_t>> &context_list) const override {  
109 - // We create context_graph at this level, because we might have default  
110 - // context_graph(will be added later if needed) that belongs to the whole  
111 - // model rather than each stream. 115 + const std::string &hotwords) const override {
  116 + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
  117 + std::istringstream is(hws);
  118 + std::vector<std::vector<int32_t>> current;
  119 + if (!EncodeHotwords(is, symbol_table_, &current)) {
  120 + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
  121 + hotwords.c_str());
  122 + }
  123 + current.insert(current.end(), hotwords_.begin(), hotwords_.end());
  124 +
112 auto context_graph = 125 auto context_graph =
113 - std::make_shared<ContextGraph>(context_list, config_.context_score); 126 + std::make_shared<ContextGraph>(current, config_.hotwords_score);
114 return std::make_unique<OfflineStream>(config_.feat_config, context_graph); 127 return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
115 } 128 }
116 129
  /** Create a stream that uses only the recognizer-level hotwords graph.
   *
   * hotwords_graph_ is a null ContextGraphPtr when no hotwords file was
   * configured (InitHotwords() is only called for a non-empty file).
   */
  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config,
                                           hotwords_graph_);
  }
120 134
121 void DecodeStreams(OfflineStream **ss, int32_t n) const override { 135 void DecodeStreams(OfflineStream **ss, int32_t n) const override {
@@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
171 } 185 }
172 } 186 }
173 187
  /** Load and encode the hotwords from config_.hotwords_file.
   *
   * Each line of the file is one hotword/phrase whose tokens (bpe pieces or
   * cjk chars) are separated by spaces. Exits the process if the file cannot
   * be opened or any token is missing from the symbol table.
   */
  void InitHotwords() {
    // each line in hotwords_file contains space-separated words

    std::ifstream is(config_.hotwords_file);
    if (!is) {
      SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                       config_.hotwords_file.c_str());
      exit(-1);
    }

    if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
      SHERPA_ONNX_LOGE("Encode hotwords failed.");
      exit(-1);
    }
    // The graph is shared by all streams created without per-stream hotwords.
    hotwords_graph_ =
        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
  }
  205 +
174 private: 206 private:
175 OfflineRecognizerConfig config_; 207 OfflineRecognizerConfig config_;
176 SymbolTable symbol_table_; 208 SymbolTable symbol_table_;
  209 + std::vector<std::vector<int32_t>> hotwords_;
  210 + ContextGraphPtr hotwords_graph_;
177 std::unique_ptr<OfflineTransducerModel> model_; 211 std::unique_ptr<OfflineTransducerModel> model_;
178 std::unique_ptr<OfflineTransducerDecoder> decoder_; 212 std::unique_ptr<OfflineTransducerDecoder> decoder_;
179 std::unique_ptr<OfflineLM> lm_; 213 std::unique_ptr<OfflineLM> lm_;
@@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { @@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
26 26
27 po->Register("max-active-paths", &max_active_paths, 27 po->Register("max-active-paths", &max_active_paths,
28 "Used only when decoding_method is modified_beam_search"); 28 "Used only when decoding_method is modified_beam_search");
29 - po->Register("context-score", &context_score, 29 +
  30 + po->Register(
  31 + "hotwords-file", &hotwords_file,
  32 + "The file containing hotwords, one words/phrases per line, and for each"
  33 + "phrase the bpe/cjkchar are separated by a space. For example: "
  34 + "▁HE LL O ▁WORLD"
  35 + "你 好 世 界");
  36 +
  37 + po->Register("hotwords-score", &hotwords_score,
30 "The bonus score for each token in context word/phrase. " 38 "The bonus score for each token in context word/phrase. "
31 "Used only when decoding_method is modified_beam_search"); 39 "Used only when decoding_method is modified_beam_search");
32 } 40 }
@@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const { @@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const {
53 os << "lm_config=" << lm_config.ToString() << ", "; 61 os << "lm_config=" << lm_config.ToString() << ", ";
54 os << "decoding_method=\"" << decoding_method << "\", "; 62 os << "decoding_method=\"" << decoding_method << "\", ";
55 os << "max_active_paths=" << max_active_paths << ", "; 63 os << "max_active_paths=" << max_active_paths << ", ";
56 - os << "context_score=" << context_score << ")"; 64 + os << "hotwords_file=\"" << hotwords_file << "\", ";
  65 + os << "hotwords_score=" << hotwords_score << ")";
57 66
58 return os.str(); 67 return os.str();
59 } 68 }
@@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) @@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
70 OfflineRecognizer::~OfflineRecognizer() = default; 79 OfflineRecognizer::~OfflineRecognizer() = default;
71 80
72 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream( 81 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
73 - const std::vector<std::vector<int32_t>> &context_list) const {  
74 - return impl_->CreateStream(context_list); 82 + const std::string &hotwords) const {
  83 + return impl_->CreateStream(hotwords);
75 } 84 }
76 85
77 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { 86 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
@@ -31,7 +31,10 @@ struct OfflineRecognizerConfig { @@ -31,7 +31,10 @@ struct OfflineRecognizerConfig {
31 31
32 std::string decoding_method = "greedy_search"; 32 std::string decoding_method = "greedy_search";
33 int32_t max_active_paths = 4; 33 int32_t max_active_paths = 4;
34 - float context_score = 1.5; 34 +
  35 + std::string hotwords_file;
  36 + float hotwords_score = 1.5;
  37 +
35 // only greedy_search is implemented 38 // only greedy_search is implemented
36 // TODO(fangjun): Implement modified_beam_search 39 // TODO(fangjun): Implement modified_beam_search
37 40
@@ -40,13 +43,16 @@ struct OfflineRecognizerConfig { @@ -40,13 +43,16 @@ struct OfflineRecognizerConfig {
40 const OfflineModelConfig &model_config, 43 const OfflineModelConfig &model_config,
41 const OfflineLMConfig &lm_config, 44 const OfflineLMConfig &lm_config,
42 const std::string &decoding_method, 45 const std::string &decoding_method,
43 - int32_t max_active_paths, float context_score) 46 + int32_t max_active_paths,
  47 + const std::string &hotwords_file,
  48 + float hotwords_score)
44 : feat_config(feat_config), 49 : feat_config(feat_config),
45 model_config(model_config), 50 model_config(model_config),
46 lm_config(lm_config), 51 lm_config(lm_config),
47 decoding_method(decoding_method), 52 decoding_method(decoding_method),
48 max_active_paths(max_active_paths), 53 max_active_paths(max_active_paths),
49 - context_score(context_score) {} 54 + hotwords_file(hotwords_file),
  55 + hotwords_score(hotwords_score) {}
50 56
51 void Register(ParseOptions *po); 57 void Register(ParseOptions *po);
52 bool Validate() const; 58 bool Validate() const;
@@ -69,9 +75,17 @@ class OfflineRecognizer { @@ -69,9 +75,17 @@ class OfflineRecognizer {
69 /// Create a stream for decoding. 75 /// Create a stream for decoding.
70 std::unique_ptr<OfflineStream> CreateStream() const; 76 std::unique_ptr<OfflineStream> CreateStream() const;
71 77
72 - /// Create a stream for decoding. 78 + /** Create a stream for decoding.
  79 + *
  80 + * @param The hotwords for this string, it might contain several hotwords,
  81 + * the hotwords are separated by "/". In each of the hotwords, there
  82 + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
  83 + * For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
  84 + *
  85 + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
  86 + */
73 std::unique_ptr<OfflineStream> CreateStream( 87 std::unique_ptr<OfflineStream> CreateStream(
74 - const std::vector<std::vector<int32_t>> &context_list) const; 88 + const std::string &hotwords) const;
75 89
76 /** Decode a single stream 90 /** Decode a single stream
77 * 91 *
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
7 7
8 #include <memory> 8 #include <memory>
  9 +#include <string>
9 #include <vector> 10 #include <vector>
10 11
11 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"
@@ -29,7 +30,7 @@ class OnlineRecognizerImpl { @@ -29,7 +30,7 @@ class OnlineRecognizerImpl {
29 virtual std::unique_ptr<OnlineStream> CreateStream() const = 0; 30 virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
30 31
31 virtual std::unique_ptr<OnlineStream> CreateStream( 32 virtual std::unique_ptr<OnlineStream> CreateStream(
32 - const std::vector<std::vector<int32_t>> &contexts) const { 33 + const std::string &hotwords) const {
33 SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); 34 SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
34 exit(-1); 35 exit(-1);
35 } 36 }
@@ -7,6 +7,8 @@ @@ -7,6 +7,8 @@
7 7
#include <algorithm>
#include <fstream>
#include <memory>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>
12 14
@@ -20,6 +22,7 @@ @@ -20,6 +22,7 @@
20 #include "sherpa-onnx/csrc/online-transducer-model.h" 22 #include "sherpa-onnx/csrc/online-transducer-model.h"
21 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" 23 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
22 #include "sherpa-onnx/csrc/symbol-table.h" 24 #include "sherpa-onnx/csrc/symbol-table.h"
  25 +#include "sherpa-onnx/csrc/utils.h"
23 26
24 namespace sherpa_onnx { 27 namespace sherpa_onnx {
25 28
@@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
57 model_(OnlineTransducerModel::Create(config.model_config)), 60 model_(OnlineTransducerModel::Create(config.model_config)),
58 sym_(config.model_config.tokens), 61 sym_(config.model_config.tokens),
59 endpoint_(config_.endpoint_config) { 62 endpoint_(config_.endpoint_config) {
  63 + if (!config_.hotwords_file.empty()) {
  64 + InitHotwords();
  65 + }
60 if (sym_.contains("<unk>")) { 66 if (sym_.contains("<unk>")) {
61 unk_id_ = sym_["<unk>"]; 67 unk_id_ = sym_["<unk>"];
62 } 68 }
@@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
106 #endif 112 #endif
107 113
  /** Create a stream that uses only the recognizer-level hotwords graph.
   *
   * hotwords_graph_ is a null ContextGraphPtr when no hotwords file was
   * configured (InitHotwords() is only called for a non-empty file).
   */
  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream =
        std::make_unique<OnlineStream>(config_.feat_config, hotwords_graph_);
    InitOnlineStream(stream.get());
    return stream;
  }
113 120
114 std::unique_ptr<OnlineStream> CreateStream( 121 std::unique_ptr<OnlineStream> CreateStream(
115 - const std::vector<std::vector<int32_t>> &contexts) const override {  
116 - // We create context_graph at this level, because we might have default  
117 - // context_graph(will be added later if needed) that belongs to the whole  
118 - // model rather than each stream. 122 + const std::string &hotwords) const override {
  123 + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
  124 + std::istringstream is(hws);
  125 + std::vector<std::vector<int32_t>> current;
  126 + if (!EncodeHotwords(is, sym_, &current)) {
  127 + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
  128 + hotwords.c_str());
  129 + }
  130 + current.insert(current.end(), hotwords_.begin(), hotwords_.end());
119 auto context_graph = 131 auto context_graph =
120 - std::make_shared<ContextGraph>(contexts, config_.context_score); 132 + std::make_shared<ContextGraph>(current, config_.hotwords_score);
121 auto stream = 133 auto stream =
122 std::make_unique<OnlineStream>(config_.feat_config, context_graph); 134 std::make_unique<OnlineStream>(config_.feat_config, context_graph);
123 InitOnlineStream(stream.get()); 135 InitOnlineStream(stream.get());
@@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
253 s->Reset(); 265 s->Reset();
254 } 266 }
255 267
  /** Load and encode the hotwords from config_.hotwords_file.
   *
   * Each line of the file is one hotword/phrase whose tokens (bpe pieces or
   * cjk chars) are separated by spaces. Exits the process if the file cannot
   * be opened or any token is missing from the symbol table.
   *
   * NOTE(review): std::ifstream requires <fstream>; confirm this translation
   * unit includes it.
   */
  void InitHotwords() {
    // each line in hotwords_file contains space-separated words

    std::ifstream is(config_.hotwords_file);
    if (!is) {
      SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                       config_.hotwords_file.c_str());
      exit(-1);
    }

    if (!EncodeHotwords(is, sym_, &hotwords_)) {
      SHERPA_ONNX_LOGE("Encode hotwords failed.");
      exit(-1);
    }
    // The graph is shared by all streams created without per-stream hotwords.
    hotwords_graph_ =
        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
  }
  285 +
256 private: 286 private:
257 void InitOnlineStream(OnlineStream *stream) const { 287 void InitOnlineStream(OnlineStream *stream) const {
258 auto r = decoder_->GetEmptyResult(); 288 auto r = decoder_->GetEmptyResult();
@@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
271 301
272 private: 302 private:
273 OnlineRecognizerConfig config_; 303 OnlineRecognizerConfig config_;
  304 + std::vector<std::vector<int32_t>> hotwords_;
  305 + ContextGraphPtr hotwords_graph_;
274 std::unique_ptr<OnlineTransducerModel> model_; 306 std::unique_ptr<OnlineTransducerModel> model_;
275 std::unique_ptr<OnlineLM> lm_; 307 std::unique_ptr<OnlineLM> lm_;
276 std::unique_ptr<OnlineTransducerDecoder> decoder_; 308 std::unique_ptr<OnlineTransducerDecoder> decoder_;
@@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
57 "True to enable endpoint detection. False to disable it."); 57 "True to enable endpoint detection. False to disable it.");
58 po->Register("max-active-paths", &max_active_paths, 58 po->Register("max-active-paths", &max_active_paths,
59 "beam size used in modified beam search."); 59 "beam size used in modified beam search.");
60 - po->Register("context-score", &context_score, 60 + po->Register("hotwords-score", &hotwords_score,
61 "The bonus score for each token in context word/phrase. " 61 "The bonus score for each token in context word/phrase. "
62 "Used only when decoding_method is modified_beam_search"); 62 "Used only when decoding_method is modified_beam_search");
  63 + po->Register(
  64 + "hotwords-file", &hotwords_file,
  65 + "The file containing hotwords, one words/phrases per line, and for each"
  66 + "phrase the bpe/cjkchar are separated by a space. For example: "
  67 + "▁HE LL O ▁WORLD"
  68 + "你 好 世 界");
63 po->Register("decoding-method", &decoding_method, 69 po->Register("decoding-method", &decoding_method,
64 "decoding method," 70 "decoding method,"
65 "now support greedy_search and modified_beam_search."); 71 "now support greedy_search and modified_beam_search.");
@@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const {
87 os << "endpoint_config=" << endpoint_config.ToString() << ", "; 93 os << "endpoint_config=" << endpoint_config.ToString() << ", ";
88 os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; 94 os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
89 os << "max_active_paths=" << max_active_paths << ", "; 95 os << "max_active_paths=" << max_active_paths << ", ";
90 - os << "context_score=" << context_score << ", "; 96 + os << "hotwords_score=" << hotwords_score << ", ";
  97 + os << "hotwords_file=\"" << hotwords_file << "\", ";
91 os << "decoding_method=\"" << decoding_method << "\")"; 98 os << "decoding_method=\"" << decoding_method << "\")";
92 99
93 return os.str(); 100 return os.str();
@@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { @@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
109 } 116 }
110 117
111 std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream( 118 std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
112 - const std::vector<std::vector<int32_t>> &context_list) const {  
113 - return impl_->CreateStream(context_list); 119 + const std::string &hotwords) const {
  120 + return impl_->CreateStream(hotwords);
114 } 121 }
115 122
116 bool OnlineRecognizer::IsReady(OnlineStream *s) const { 123 bool OnlineRecognizer::IsReady(OnlineStream *s) const {
@@ -78,8 +78,10 @@ struct OnlineRecognizerConfig { @@ -78,8 +78,10 @@ struct OnlineRecognizerConfig {
78 78
79 // used only for modified_beam_search 79 // used only for modified_beam_search
80 int32_t max_active_paths = 4; 80 int32_t max_active_paths = 4;
  81 +
81 /// used only for modified_beam_search 82 /// used only for modified_beam_search
82 - float context_score = 1.5; 83 + float hotwords_score = 1.5;
  84 + std::string hotwords_file;
83 85
84 OnlineRecognizerConfig() = default; 86 OnlineRecognizerConfig() = default;
85 87
@@ -89,14 +91,16 @@ struct OnlineRecognizerConfig { @@ -89,14 +91,16 @@ struct OnlineRecognizerConfig {
89 const EndpointConfig &endpoint_config, 91 const EndpointConfig &endpoint_config,
90 bool enable_endpoint, 92 bool enable_endpoint,
91 const std::string &decoding_method, 93 const std::string &decoding_method,
92 - int32_t max_active_paths, float context_score) 94 + int32_t max_active_paths,
  95 + const std::string &hotwords_file, float hotwords_score)
93 : feat_config(feat_config), 96 : feat_config(feat_config),
94 model_config(model_config), 97 model_config(model_config),
95 endpoint_config(endpoint_config), 98 endpoint_config(endpoint_config),
96 enable_endpoint(enable_endpoint), 99 enable_endpoint(enable_endpoint),
97 decoding_method(decoding_method), 100 decoding_method(decoding_method),
98 max_active_paths(max_active_paths), 101 max_active_paths(max_active_paths),
99 - context_score(context_score) {} 102 + hotwords_score(hotwords_score),
  103 + hotwords_file(hotwords_file) {}
100 104
101 void Register(ParseOptions *po); 105 void Register(ParseOptions *po);
102 bool Validate() const; 106 bool Validate() const;
@@ -119,9 +123,16 @@ class OnlineRecognizer { @@ -119,9 +123,16 @@ class OnlineRecognizer {
119 /// Create a stream for decoding. 123 /// Create a stream for decoding.
120 std::unique_ptr<OnlineStream> CreateStream() const; 124 std::unique_ptr<OnlineStream> CreateStream() const;
121 125
122 - // Create a stream with context phrases  
123 - std::unique_ptr<OnlineStream> CreateStream(  
124 - const std::vector<std::vector<int32_t>> &context_list) const; 126 + /** Create a stream for decoding.
  127 + *
  128 + * @param The hotwords for this string, it might contain several hotwords,
  129 + * the hotwords are separated by "/". In each of the hotwords, there
  130 + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
  131 + * For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
  132 + *
  133 + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
  134 + */
  135 + std::unique_ptr<OnlineStream> CreateStream(const std::string &hotwords) const;
125 136
126 /** 137 /**
127 * Return true if the given stream has enough frames for decoding. 138 * Return true if the given stream has enough frames for decoding.
  1 +// sherpa-onnx/csrc/utils.cc
  2 +//
  3 +// Copyright 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/utils.h"
  6 +
  7 +#include <iostream>
  8 +#include <sstream>
  9 +#include <string>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "sherpa-onnx/csrc/log.h"
  14 +#include "sherpa-onnx/csrc/macros.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
  19 + std::vector<std::vector<int32_t>> *hotwords) {
  20 + hotwords->clear();
  21 + std::vector<int32_t> tmp;
  22 + std::string line;
  23 + std::string word;
  24 +
  25 + while (std::getline(is, line)) {
  26 + std::istringstream iss(line);
  27 + std::vector<std::string> syms;
  28 + while (iss >> word) {
  29 + if (word.size() >= 3) {
  30 + // For BPE-based models, we replace ▁ with a space
  31 + // Unicode 9601, hex 0x2581, utf8 0xe29681
  32 + const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
  33 + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
  34 + word = word.replace(0, 3, " ");
  35 + }
  36 + }
  37 + if (symbol_table.contains(word)) {
  38 + int32_t number = symbol_table[word];
  39 + tmp.push_back(number);
  40 + } else {
  41 + SHERPA_ONNX_LOGE(
  42 + "Cannot find ID for hotword %s at line: %s. (Hint: words on "
  43 + "the "
  44 + "same line are separated by spaces)",
  45 + word.c_str(), line.c_str());
  46 + return false;
  47 + }
  48 + }
  49 + hotwords->push_back(std::move(tmp));
  50 + }
  51 + return true;
  52 +}
  53 +
  54 +} // namespace sherpa_onnx
// sherpa-onnx/csrc/utils.h
//
// Copyright      2023  Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_UTILS_H_
#define SHERPA_ONNX_CSRC_UTILS_H_

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/symbol-table.h"

// NOTE(review): std::istream is used below; consider including <istream>
// (or <iosfwd>) explicitly instead of relying on transitive includes.

namespace sherpa_onnx {

/* Encode the hotwords in an input stream to be token ids.
 *
 * @param is The input stream, it contains several lines, one hotword for each
 *           line. For each hotword, the tokens (cjkchar or bpe) are separated
 *           by spaces.
 * @param symbol_table The tokens table mapping symbols to ids. All the symbols
 *                     in the stream should be in the symbol_table, if not this
 *                     function returns false.
 *
 * @param hotwords The encoded ids to be written to.
 *
 * @return If all the symbols from ``is`` are in the symbol_table, returns true
 *         otherwise returns false.
 */
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
                    std::vector<std::vector<int32_t>> *hotwords);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_UTILS_H_
@@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
16 py::class_<PyClass>(*m, "OfflineRecognizerConfig") 16 py::class_<PyClass>(*m, "OfflineRecognizerConfig")
17 .def(py::init<const OfflineFeatureExtractorConfig &, 17 .def(py::init<const OfflineFeatureExtractorConfig &,
18 const OfflineModelConfig &, const OfflineLMConfig &, 18 const OfflineModelConfig &, const OfflineLMConfig &,
19 - const std::string &, int32_t, float>(), 19 + const std::string &, int32_t, const std::string &, float>(),
20 py::arg("feat_config"), py::arg("model_config"), 20 py::arg("feat_config"), py::arg("model_config"),
21 py::arg("lm_config") = OfflineLMConfig(), 21 py::arg("lm_config") = OfflineLMConfig(),
22 py::arg("decoding_method") = "greedy_search", 22 py::arg("decoding_method") = "greedy_search",
23 - py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5) 23 + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
  24 + py::arg("hotwords_score") = 1.5)
24 .def_readwrite("feat_config", &PyClass::feat_config) 25 .def_readwrite("feat_config", &PyClass::feat_config)
25 .def_readwrite("model_config", &PyClass::model_config) 26 .def_readwrite("model_config", &PyClass::model_config)
26 .def_readwrite("lm_config", &PyClass::lm_config) 27 .def_readwrite("lm_config", &PyClass::lm_config)
27 .def_readwrite("decoding_method", &PyClass::decoding_method) 28 .def_readwrite("decoding_method", &PyClass::decoding_method)
28 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 29 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
29 - .def_readwrite("context_score", &PyClass::context_score) 30 + .def_readwrite("hotwords_file", &PyClass::hotwords_file)
  31 + .def_readwrite("hotwords_score", &PyClass::hotwords_score)
30 .def("__str__", &PyClass::ToString); 32 .def("__str__", &PyClass::ToString);
31 } 33 }
32 34
@@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) { @@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) {
40 [](const PyClass &self) { return self.CreateStream(); }) 42 [](const PyClass &self) { return self.CreateStream(); })
41 .def( 43 .def(
42 "create_stream", 44 "create_stream",
43 - [](PyClass &self,  
44 - const std::vector<std::vector<int32_t>> &contexts_list) {  
45 - return self.CreateStream(contexts_list); 45 + [](PyClass &self, const std::string &hotwords) {
  46 + return self.CreateStream(hotwords);
46 }, 47 },
47 - py::arg("contexts_list")) 48 + py::arg("hotwords"))
48 .def("decode_stream", &PyClass::DecodeStream) 49 .def("decode_stream", &PyClass::DecodeStream)
49 .def("decode_streams", 50 .def("decode_streams",
50 [](const PyClass &self, std::vector<OfflineStream *> ss) { 51 [](const PyClass &self, std::vector<OfflineStream *> ss) {
@@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) { @@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) {
21 using PyClass = OnlineModelConfig; 21 using PyClass = OnlineModelConfig;
22 py::class_<PyClass>(*m, "OnlineModelConfig") 22 py::class_<PyClass>(*m, "OnlineModelConfig")
23 .def(py::init<const OnlineTransducerModelConfig &, 23 .def(py::init<const OnlineTransducerModelConfig &,
24 - const OnlineParaformerModelConfig &, std::string &, int32_t,  
25 - bool, const std::string &, const std::string &>(), 24 + const OnlineParaformerModelConfig &, const std::string &,
  25 + int32_t, bool, const std::string &, const std::string &>(),
26 py::arg("transducer") = OnlineTransducerModelConfig(), 26 py::arg("transducer") = OnlineTransducerModelConfig(),
27 py::arg("paraformer") = OnlineParaformerModelConfig(), 27 py::arg("paraformer") = OnlineParaformerModelConfig(),
28 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, 28 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
@@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
29 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 29 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
30 .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, 30 .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
31 const OnlineLMConfig &, const EndpointConfig &, bool, 31 const OnlineLMConfig &, const EndpointConfig &, bool,
32 - const std::string &, int32_t, float>(), 32 + const std::string &, int32_t, const std::string &, float>(),
33 py::arg("feat_config"), py::arg("model_config"), 33 py::arg("feat_config"), py::arg("model_config"),
34 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), 34 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
35 py::arg("enable_endpoint"), py::arg("decoding_method"), 35 py::arg("enable_endpoint"), py::arg("decoding_method"),
36 - py::arg("max_active_paths") = 4, py::arg("context_score") = 0) 36 + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
  37 + py::arg("hotwords_score") = 0)
37 .def_readwrite("feat_config", &PyClass::feat_config) 38 .def_readwrite("feat_config", &PyClass::feat_config)
38 .def_readwrite("model_config", &PyClass::model_config) 39 .def_readwrite("model_config", &PyClass::model_config)
39 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 40 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
40 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) 41 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
41 .def_readwrite("decoding_method", &PyClass::decoding_method) 42 .def_readwrite("decoding_method", &PyClass::decoding_method)
42 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 43 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
43 - .def_readwrite("context_score", &PyClass::context_score) 44 + .def_readwrite("hotwords_file", &PyClass::hotwords_file)
  45 + .def_readwrite("hotwords_score", &PyClass::hotwords_score)
44 .def("__str__", &PyClass::ToString); 46 .def("__str__", &PyClass::ToString);
45 } 47 }
46 48
@@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) { @@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) {
55 [](const PyClass &self) { return self.CreateStream(); }) 57 [](const PyClass &self) { return self.CreateStream(); })
56 .def( 58 .def(
57 "create_stream", 59 "create_stream",
58 - [](PyClass &self,  
59 - const std::vector<std::vector<int32_t>> &contexts_list) {  
60 - return self.CreateStream(contexts_list); 60 + [](PyClass &self, const std::string &hotwords) {
  61 + return self.CreateStream(hotwords);
61 }, 62 },
62 - py::arg("contexts_list")) 63 + py::arg("hotwords"))
63 .def("is_ready", &PyClass::IsReady) 64 .def("is_ready", &PyClass::IsReady)
64 .def("decode_stream", &PyClass::DecodeStream) 65 .def("decode_stream", &PyClass::DecodeStream)
65 .def("decode_streams", 66 .def("decode_streams",
@@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream @@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream
4 4
5 from .offline_recognizer import OfflineRecognizer 5 from .offline_recognizer import OfflineRecognizer
6 from .online_recognizer import OnlineRecognizer 6 from .online_recognizer import OnlineRecognizer
7 -from .utils import encode_contexts 7 +from .utils import text2token
# Copyright (c)  2023  Xiaomi Corporation
"""Command-line entry points for sherpa-onnx (currently: text2token)."""

import logging
from pathlib import Path

import click

from sherpa_onnx import text2token


@click.group()
def cli():
    """
    The shell entry point to sherpa-onnx.
    """
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )


@cli.command(name="text2token")
@click.argument("input", type=click.Path(exists=True, dir_okay=False))
@click.argument("output", type=click.Path())
@click.option(
    "--tokens",
    type=str,
    required=True,
    help="The path to tokens.txt.",
)
@click.option(
    # Restrict to the values text2token actually supports so an invalid
    # value fails at argument parsing with a helpful message instead of
    # deep inside text2token.
    "--tokens-type",
    type=click.Choice(["cjkchar", "bpe", "cjkchar+bpe"]),
    required=True,
    help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
)
@click.option(
    "--bpe-model",
    type=str,
    help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.",
)
def encode_text(
    input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path
):
    """
    Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
    """
    # Fail early with a clear CLI error; otherwise text2token would raise
    # ``TypeError: Path(None)`` when the bpe model is missing.
    if "bpe" in tokens_type and not bpe_model:
        raise click.UsageError(
            "--bpe-model is required when --tokens-type is bpe or cjkchar+bpe"
        )
    # One input line == one text to encode; strip trailing newlines/spaces.
    texts = []
    with open(input, "r", encoding="utf8") as f:
        for line in f:
            texts.append(line.strip())
    encoded_texts = text2token(
        texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
    )
    # One output line per input text, tokens separated by a single space.
    with open(output, "w", encoding="utf8") as f:
        for txt in encoded_texts:
            f.write(" ".join(txt) + "\n")
@@ -43,7 +43,8 @@ class OfflineRecognizer(object): @@ -43,7 +43,8 @@ class OfflineRecognizer(object):
43 feature_dim: int = 80, 43 feature_dim: int = 80,
44 decoding_method: str = "greedy_search", 44 decoding_method: str = "greedy_search",
45 max_active_paths: int = 4, 45 max_active_paths: int = 4,
46 - context_score: float = 1.5, 46 + hotwords_file: str = "",
  47 + hotwords_score: float = 1.5,
47 debug: bool = False, 48 debug: bool = False,
48 provider: str = "cpu", 49 provider: str = "cpu",
49 ): 50 ):
@@ -105,7 +106,8 @@ class OfflineRecognizer(object): @@ -105,7 +106,8 @@ class OfflineRecognizer(object):
105 feat_config=feat_config, 106 feat_config=feat_config,
106 model_config=model_config, 107 model_config=model_config,
107 decoding_method=decoding_method, 108 decoding_method=decoding_method,
108 - context_score=context_score, 109 + hotwords_file=hotwords_file,
  110 + hotwords_score=hotwords_score,
109 ) 111 )
110 self.recognizer = _Recognizer(recognizer_config) 112 self.recognizer = _Recognizer(recognizer_config)
111 self.config = recognizer_config 113 self.config = recognizer_config
@@ -379,11 +381,11 @@ class OfflineRecognizer(object): @@ -379,11 +381,11 @@ class OfflineRecognizer(object):
379 self.config = recognizer_config 381 self.config = recognizer_config
380 return self 382 return self
381 383
def create_stream(self, hotwords: Optional[str] = None):
    """Create an OfflineStream.

    If ``hotwords`` is given, the stream is created with those hotwords
    enabled; otherwise a plain stream is created.
    """
    if hotwords is not None:
        return self.recognizer.create_stream(hotwords)
    return self.recognizer.create_stream()
387 389
    def decode_stream(self, s: OfflineStream):
        """Run offline decoding on the given stream (delegates to the
        underlying C++ recognizer; results are read from the stream)."""
        self.recognizer.decode_stream(s)
@@ -42,7 +42,8 @@ class OnlineRecognizer(object): @@ -42,7 +42,8 @@ class OnlineRecognizer(object):
42 rule3_min_utterance_length: float = 20.0, 42 rule3_min_utterance_length: float = 20.0,
43 decoding_method: str = "greedy_search", 43 decoding_method: str = "greedy_search",
44 max_active_paths: int = 4, 44 max_active_paths: int = 4,
45 - context_score: float = 1.5, 45 + hotwords_score: float = 1.5,
  46 + hotwords_file: str = "",
46 provider: str = "cpu", 47 provider: str = "cpu",
47 model_type: str = "", 48 model_type: str = "",
48 ): 49 ):
@@ -138,7 +139,8 @@ class OnlineRecognizer(object): @@ -138,7 +139,8 @@ class OnlineRecognizer(object):
138 enable_endpoint=enable_endpoint_detection, 139 enable_endpoint=enable_endpoint_detection,
139 decoding_method=decoding_method, 140 decoding_method=decoding_method,
140 max_active_paths=max_active_paths, 141 max_active_paths=max_active_paths,
141 - context_score=context_score, 142 + hotwords_score=hotwords_score,
  143 + hotwords_file=hotwords_file,
142 ) 144 )
143 145
144 self.recognizer = _Recognizer(recognizer_config) 146 self.recognizer = _Recognizer(recognizer_config)
@@ -248,11 +250,11 @@ class OnlineRecognizer(object): @@ -248,11 +250,11 @@ class OnlineRecognizer(object):
248 self.config = recognizer_config 250 self.config = recognizer_config
249 return self 251 return self
250 252
def create_stream(self, hotwords: Optional[str] = None):
    """Create an OnlineStream.

    If ``hotwords`` is given, the stream is created with those hotwords
    enabled; otherwise a plain stream is created.
    """
    if hotwords is not None:
        return self.recognizer.create_stream(hotwords)
    return self.recognizer.create_stream()
256 258
    def decode_stream(self, s: OnlineStream):
        """Run one decoding step on the given stream (delegates to the
        underlying C++ recognizer; results are read from the stream)."""
        self.recognizer.decode_stream(s)
# Copyright (c)  2023  Xiaomi Corporation
import re

from pathlib import Path
from typing import List, Optional, Union


def text2token(
    texts: List[str],
    tokens: str,
    tokens_type: str = "cjkchar",
    bpe_model: Optional[str] = None,
    output_ids: bool = False,
) -> List[List[Union[str, int]]]:
    """
    Encode the given texts (a list of string) to a list of a list of tokens.

    Args:
      texts:
        The given contexts list (a list of string).
      tokens:
        The path of the tokens.txt.
      tokens_type:
        The valid values are cjkchar, bpe, cjkchar+bpe.
      bpe_model:
        The path of the bpe model. Only required when tokens_type is bpe or
        cjkchar+bpe.
      output_ids:
        True to output token ids otherwise tokens.
    Returns:
      Return the encoded texts, it is a list of a list of token ids if output_ids
      is True, or it is a list of list of tokens. Texts containing a token that
      is not in the tokens table are skipped entirely (a message is printed for
      each skipped text).
    """
    assert Path(tokens).is_file(), f"File not exists, {tokens}"
    # Map token (symbol) -> integer id, read from tokens.txt ("symbol id" per line).
    tokens_table = {}
    with open(tokens, "r", encoding="utf-8") as f:
        for line in f:
            toks = line.strip().split()
            assert len(toks) == 2, len(toks)
            assert toks[0] not in tokens_table, f"Duplicate token: {toks} "
            tokens_table[toks[0]] = int(toks[1])

    if "bpe" in tokens_type:
        # Import lazily so sentencepiece is only required for the bpe token
        # types; pure cjkchar usage works without it installed.
        import sentencepiece as spm

        # Check for None explicitly; Path(None) would raise a confusing TypeError.
        assert (
            bpe_model is not None
        ), "bpe_model is required when tokens_type is bpe or cjkchar+bpe"
        assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}"
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    texts_list: List[List[str]] = []

    if tokens_type == "cjkchar":
        # Drop all whitespace, then split into individual characters.
        texts_list = [list("".join(text.split())) for text in texts]
    elif tokens_type == "bpe":
        texts_list = sp.encode(texts, out_type=str)
    else:
        assert (
            tokens_type == "cjkchar+bpe"
        ), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"
        # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        pattern = re.compile(r"([\u4e00-\u9fff])")
        for text in texts:
            # Example:
            #   txt = "你好 ITS'S OKAY 的"
            #   chars = ["你", "好", " ITS'S OKAY ", "的"]
            chars = pattern.split(text)
            mix_chars = [w for w in chars if len(w.strip()) > 0]
            text_list = []
            for ch_or_w in mix_chars:
                # ch_or_w is a single CJK character (i.e., "你"), do nothing.
                if pattern.fullmatch(ch_or_w) is not None:
                    text_list.append(ch_or_w)
                # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "),
                # encode ch_or_w using bpe_model.
                else:
                    text_list += sp.encode_as_pieces(ch_or_w)
            texts_list.append(text_list)

    # Map tokens to ids (or keep tokens), skipping any text that contains an
    # out-of-vocabulary token.
    result: List[List[Union[int, str]]] = []
    for text in texts_list:
        text_list = []
        contain_oov = False
        for txt in text:
            if txt in tokens_table:
                text_list.append(tokens_table[txt] if output_ids else txt)
            else:
                print(f"OOV token : {txt}, skipping text : {text}.")
                contain_oov = True
                break
        if contain_oov:
            continue
        else:
            result.append(text_list)
    return result
@@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source) @@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source)
6 COMMAND 6 COMMAND
7 "${PYTHON_EXECUTABLE}" 7 "${PYTHON_EXECUTABLE}"
8 "${CMAKE_CURRENT_SOURCE_DIR}/${source}" 8 "${CMAKE_CURRENT_SOURCE_DIR}/${source}"
  9 + WORKING_DIRECTORY
  10 + ${CMAKE_CURRENT_SOURCE_DIR}
9 ) 11 )
10 12
11 get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) 13 get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
12 14
13 set_property(TEST ${name} 15 set_property(TEST ${name}
14 - PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}" 16 + PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
15 ) 17 )
16 endfunction() 18 endfunction()
17 19
@@ -21,6 +23,7 @@ set(py_test_files @@ -21,6 +23,7 @@ set(py_test_files
21 test_offline_recognizer.py 23 test_offline_recognizer.py
22 test_online_recognizer.py 24 test_online_recognizer.py
23 test_online_transducer_model_config.py 25 test_online_transducer_model_config.py
  26 + test_text2token.py
24 ) 27 )
25 28
26 foreach(source IN LISTS py_test_files) 29 foreach(source IN LISTS py_test_files)
# sherpa-onnx/python/tests/test_text2token.py
#
# Copyright (c)  2023  Xiaomi Corporation
#
# To run this single test, use
#
#  ctest --verbose -R  test_text2token_py

import unittest
from pathlib import Path

import sherpa_onnx

d = "/tmp/sherpa-test-data"
# Please refer to
# https://github.com/pkufool/sherpa-test-data
# to download test data for testing


class TestText2Token(unittest.TestCase):
    def _data_missing(self, name, *paths):
        """Return True (after printing a download hint) if any required
        test-data file is absent; used to skip tests gracefully."""
        if all(Path(p).is_file() for p in paths):
            return False
        print(
            f"No test data found, skipping {name}.\n"
            f"You can download the test data by: \n"
            f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
        )
        return True

    def test_bpe(self):
        tokens = f"{d}/text2token/tokens_en.txt"
        bpe_model = f"{d}/text2token/bpe_en.model"

        if self._data_missing("test_bpe()", tokens, bpe_model):
            return

        texts = ["HELLO WORLD", "I LOVE YOU"]
        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
        )
        expected = [
            ["▁HE", "LL", "O", "▁WORLD"],
            ["▁I", "▁LOVE", "▁YOU"],
        ]
        assert encoded_texts == expected, encoded_texts

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids

    def test_cjkchar(self):
        tokens = f"{d}/text2token/tokens_cn.txt"

        if self._data_missing("test_cjkchar()", tokens):
            return

        texts = ["世界人民大团结", "中国 VS 美国"]
        encoded_texts = sherpa_onnx.text2token(
            texts, tokens=tokens, tokens_type="cjkchar"
        )
        expected = [
            ["世", "界", "人", "民", "大", "团", "结"],
            ["中", "国", "V", "S", "美", "国"],
        ]
        assert encoded_texts == expected, encoded_texts

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar",
            output_ids=True,
        )
        expected_ids = [
            [379, 380, 72, 874, 93, 1251, 489],
            [262, 147, 3423, 2476, 21, 147],
        ]
        assert encoded_ids == expected_ids, encoded_ids

    def test_cjkchar_bpe(self):
        tokens = f"{d}/text2token/tokens_mix.txt"
        bpe_model = f"{d}/text2token/bpe_mix.model"

        if self._data_missing("test_cjkchar_bpe()", tokens, bpe_model):
            return

        texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]
        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
        )
        expected = [
            ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
            ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
        ]
        assert encoded_texts == expected, encoded_texts

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        expected_ids = [
            [1368, 1392, 557, 680, 275, 178, 475],
            [685, 736, 275, 178, 179, 921, 736],
        ]
        assert encoded_ids == expected_ids, encoded_ids


if __name__ == "__main__":
    unittest.main()