Committed by GitHub
Refactor hotwords, support loading hotwords from file (#296)
Showing 34 changed files with 800 additions and 297 deletions.
| @@ -166,3 +166,8 @@ python3 ./python-api-examples/offline-decode-files.py \ | @@ -166,3 +166,8 @@ python3 ./python-api-examples/offline-decode-files.py \ | ||
| 166 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | 166 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose |
| 167 | 167 | ||
| 168 | rm -rf $repo | 168 | rm -rf $repo |
| 169 | + | ||
| 170 | +# test text2token | ||
| 171 | +git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data | ||
| 172 | + | ||
| 173 | +python3 sherpa-onnx/python/tests/test_text2token.py --verbose |
| @@ -39,7 +39,7 @@ jobs: | @@ -39,7 +39,7 @@ jobs: | ||
| 39 | - name: Install Python dependencies | 39 | - name: Install Python dependencies |
| 40 | shell: bash | 40 | shell: bash |
| 41 | run: | | 41 | run: | |
| 42 | - python3 -m pip install --upgrade pip numpy | 42 | + python3 -m pip install --upgrade pip numpy sentencepiece |
| 43 | 43 | ||
| 44 | - name: Install sherpa-onnx | 44 | - name: Install sherpa-onnx |
| 45 | shell: bash | 45 | shell: bash |
| @@ -39,7 +39,7 @@ jobs: | @@ -39,7 +39,7 @@ jobs: | ||
| 39 | - name: Install Python dependencies | 39 | - name: Install Python dependencies |
| 40 | shell: bash | 40 | shell: bash |
| 41 | run: | | 41 | run: | |
| 42 | - python3 -m pip install --upgrade pip numpy | 42 | + python3 -m pip install --upgrade pip numpy sentencepiece |
| 43 | 43 | ||
| 44 | - name: Install sherpa-onnx | 44 | - name: Install sherpa-onnx |
| 45 | shell: bash | 45 | shell: bash |
| @@ -326,6 +326,31 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): | @@ -326,6 +326,31 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): | ||
| 326 | ) | 326 | ) |
| 327 | 327 | ||
| 328 | 328 | ||
| 329 | +def add_hotwords_args(parser: argparse.ArgumentParser): | ||
| 330 | + parser.add_argument( | ||
| 331 | + "--hotwords-file", | ||
| 332 | + type=str, | ||
| 333 | + default="", | ||
| 334 | + help=""" | ||
| 335 | + The file containing hotwords, one words/phrases per line, and for each | ||
| 336 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 337 | + | ||
| 338 | + ▁HE LL O ▁WORLD | ||
| 339 | + 你 好 世 界 | ||
| 340 | + """, | ||
| 341 | + ) | ||
| 342 | + | ||
| 343 | + parser.add_argument( | ||
| 344 | + "--hotwords-score", | ||
| 345 | + type=float, | ||
| 346 | + default=1.5, | ||
| 347 | + help=""" | ||
| 348 | + The hotword score of each token for biasing word/phrase. Used only if | ||
| 349 | + --hotwords-file is given. | ||
| 350 | + """, | ||
| 351 | + ) | ||
| 352 | + | ||
| 353 | + | ||
| 329 | def check_args(args): | 354 | def check_args(args): |
| 330 | if not Path(args.tokens).is_file(): | 355 | if not Path(args.tokens).is_file(): |
| 331 | raise ValueError(f"{args.tokens} does not exist") | 356 | raise ValueError(f"{args.tokens} does not exist") |
| @@ -342,6 +367,10 @@ def check_args(args): | @@ -342,6 +367,10 @@ def check_args(args): | ||
| 342 | assert Path(args.decoder).is_file(), args.decoder | 367 | assert Path(args.decoder).is_file(), args.decoder |
| 343 | assert Path(args.joiner).is_file(), args.joiner | 368 | assert Path(args.joiner).is_file(), args.joiner |
| 344 | 369 | ||
| 370 | + if args.hotwords_file != "": | ||
| 371 | + assert args.decoding_method == "modified_beam_search", args.decoding_method | ||
| 372 | + assert Path(args.hotwords_file).is_file(), args.hotwords_file | ||
| 373 | + | ||
| 345 | 374 | ||
| 346 | def get_args(): | 375 | def get_args(): |
| 347 | parser = argparse.ArgumentParser( | 376 | parser = argparse.ArgumentParser( |
| @@ -351,6 +380,7 @@ def get_args(): | @@ -351,6 +380,7 @@ def get_args(): | ||
| 351 | add_model_args(parser) | 380 | add_model_args(parser) |
| 352 | add_feature_config_args(parser) | 381 | add_feature_config_args(parser) |
| 353 | add_decoding_args(parser) | 382 | add_decoding_args(parser) |
| 383 | + add_hotwords_args(parser) | ||
| 354 | 384 | ||
| 355 | parser.add_argument( | 385 | parser.add_argument( |
| 356 | "--port", | 386 | "--port", |
| @@ -792,6 +822,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -792,6 +822,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 792 | feature_dim=args.feat_dim, | 822 | feature_dim=args.feat_dim, |
| 793 | decoding_method=args.decoding_method, | 823 | decoding_method=args.decoding_method, |
| 794 | max_active_paths=args.max_active_paths, | 824 | max_active_paths=args.max_active_paths, |
| 825 | + hotwords_file=args.hotwords_file, | ||
| 826 | + hotwords_score=args.hotwords_score, | ||
| 795 | ) | 827 | ) |
| 796 | elif args.paraformer: | 828 | elif args.paraformer: |
| 797 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 829 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
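For concreteness, a hotwords file in the format described by the --hotwords-file help above contains one pre-tokenized word or phrase per line, with the bpe/cjkchar units of each phrase separated by spaces (the entries below are taken directly from the help text; real entries must use the same units as the model's tokens.txt):

    ▁HE LL O ▁WORLD
    你 好 世 界

Plain text can be converted into this form with the scripts/text2token.py utility added later in this diff.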
| @@ -82,7 +82,6 @@ from pathlib import Path | @@ -82,7 +82,6 @@ from pathlib import Path | ||
| 82 | from typing import List, Tuple | 82 | from typing import List, Tuple |
| 83 | 83 | ||
| 84 | import numpy as np | 84 | import numpy as np |
| 85 | -import sentencepiece as spm | ||
| 86 | import sherpa_onnx | 85 | import sherpa_onnx |
| 87 | 86 | ||
| 88 | 87 | ||
| @@ -98,43 +97,25 @@ def get_args(): | @@ -98,43 +97,25 @@ def get_args(): | ||
| 98 | ) | 97 | ) |
| 99 | 98 | ||
| 100 | parser.add_argument( | 99 | parser.add_argument( |
| 101 | - "--bpe-model", | 100 | + "--hotwords-file", |
| 102 | type=str, | 101 | type=str, |
| 103 | default="", | 102 | default="", |
| 104 | help=""" | 103 | help=""" |
| 105 | - Path to bpe.model, | ||
| 106 | - Used only when --decoding-method=modified_beam_search | ||
| 107 | - """, | ||
| 108 | - ) | 104 | + The file containing hotwords, one words/phrases per line, and for each |
| 105 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 109 | 106 | ||
| 110 | - parser.add_argument( | ||
| 111 | - "--modeling-unit", | ||
| 112 | - type=str, | ||
| 113 | - default="char", | ||
| 114 | - help=""" | ||
| 115 | - The type of modeling unit. | ||
| 116 | - Valid values are bpe, bpe+char, char. | ||
| 117 | - Note: the char here means characters in CJK languages. | 107 | + ▁HE LL O ▁WORLD |
| 108 | + 你 好 世 界 | ||
| 118 | """, | 109 | """, |
| 119 | ) | 110 | ) |
| 120 | 111 | ||
| 121 | parser.add_argument( | 112 | parser.add_argument( |
| 122 | - "--contexts", | ||
| 123 | - type=str, | ||
| 124 | - default="", | ||
| 125 | - help=""" | ||
| 126 | - The context list, it is a string containing some words/phrases separated | ||
| 127 | - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". | ||
| 128 | - """, | ||
| 129 | - ) | ||
| 130 | - | ||
| 131 | - parser.add_argument( | ||
| 132 | - "--context-score", | 113 | + "--hotwords-score", |
| 133 | type=float, | 114 | type=float, |
| 134 | default=1.5, | 115 | default=1.5, |
| 135 | help=""" | 116 | help=""" |
| 136 | - The context score of each token for biasing word/phrase. Used only if | ||
| 137 | - --contexts is given. | 117 | + The hotword score of each token for biasing word/phrase. Used only if |
| 118 | + --hotwords-file is given. | ||
| 138 | """, | 119 | """, |
| 139 | ) | 120 | ) |
| 140 | 121 | ||
| @@ -273,25 +254,6 @@ def assert_file_exists(filename: str): | @@ -273,25 +254,6 @@ def assert_file_exists(filename: str): | ||
| 273 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | 254 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" |
| 274 | ) | 255 | ) |
| 275 | 256 | ||
| 276 | - | ||
| 277 | -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | ||
| 278 | - sp = None | ||
| 279 | - if "bpe" in args.modeling_unit: | ||
| 280 | - assert_file_exists(args.bpe_model) | ||
| 281 | - sp = spm.SentencePieceProcessor() | ||
| 282 | - sp.load(args.bpe_model) | ||
| 283 | - tokens = {} | ||
| 284 | - with open(args.tokens, "r", encoding="utf-8") as f: | ||
| 285 | - for line in f: | ||
| 286 | - toks = line.strip().split() | ||
| 287 | - assert len(toks) == 2, len(toks) | ||
| 288 | - assert toks[0] not in tokens, f"Duplicate token: {toks} " | ||
| 289 | - tokens[toks[0]] = int(toks[1]) | ||
| 290 | - return sherpa_onnx.encode_contexts( | ||
| 291 | - modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens | ||
| 292 | - ) | ||
| 293 | - | ||
| 294 | - | ||
| 295 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | 257 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: |
| 296 | """ | 258 | """ |
| 297 | Args: | 259 | Args: |
| @@ -322,7 +284,6 @@ def main(): | @@ -322,7 +284,6 @@ def main(): | ||
| 322 | assert_file_exists(args.tokens) | 284 | assert_file_exists(args.tokens) |
| 323 | assert args.num_threads > 0, args.num_threads | 285 | assert args.num_threads > 0, args.num_threads |
| 324 | 286 | ||
| 325 | - contexts_list = [] | ||
| 326 | if args.encoder: | 287 | if args.encoder: |
| 327 | assert len(args.paraformer) == 0, args.paraformer | 288 | assert len(args.paraformer) == 0, args.paraformer |
| 328 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 289 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
| @@ -330,11 +291,6 @@ def main(): | @@ -330,11 +291,6 @@ def main(): | ||
| 330 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 291 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 331 | assert len(args.tdnn_model) == 0, args.tdnn_model | 292 | assert len(args.tdnn_model) == 0, args.tdnn_model |
| 332 | 293 | ||
| 333 | - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] | ||
| 334 | - if contexts: | ||
| 335 | - print(f"Contexts list: {contexts}") | ||
| 336 | - contexts_list = encode_contexts(args, contexts) | ||
| 337 | - | ||
| 338 | assert_file_exists(args.encoder) | 294 | assert_file_exists(args.encoder) |
| 339 | assert_file_exists(args.decoder) | 295 | assert_file_exists(args.decoder) |
| 340 | assert_file_exists(args.joiner) | 296 | assert_file_exists(args.joiner) |
| @@ -348,7 +304,8 @@ def main(): | @@ -348,7 +304,8 @@ def main(): | ||
| 348 | sample_rate=args.sample_rate, | 304 | sample_rate=args.sample_rate, |
| 349 | feature_dim=args.feature_dim, | 305 | feature_dim=args.feature_dim, |
| 350 | decoding_method=args.decoding_method, | 306 | decoding_method=args.decoding_method, |
| 351 | - context_score=args.context_score, | 307 | + hotwords_file=args.hotwords_file, |
| 308 | + hotwords_score=args.hotwords_score, | ||
| 352 | debug=args.debug, | 309 | debug=args.debug, |
| 353 | ) | 310 | ) |
| 354 | elif args.paraformer: | 311 | elif args.paraformer: |
| @@ -425,12 +382,7 @@ def main(): | @@ -425,12 +382,7 @@ def main(): | ||
| 425 | samples, sample_rate = read_wave(wave_filename) | 382 | samples, sample_rate = read_wave(wave_filename) |
| 426 | duration = len(samples) / sample_rate | 383 | duration = len(samples) / sample_rate |
| 427 | total_duration += duration | 384 | total_duration += duration |
| 428 | - if contexts_list: | ||
| 429 | - assert len(args.paraformer) == 0, args.paraformer | ||
| 430 | - assert len(args.nemo_ctc) == 0, args.nemo_ctc | ||
| 431 | - s = recognizer.create_stream(contexts_list=contexts_list) | ||
| 432 | - else: | ||
| 433 | - s = recognizer.create_stream() | 385 | + s = recognizer.create_stream() |
| 434 | s.accept_waveform(sample_rate, samples) | 386 | s.accept_waveform(sample_rate, samples) |
| 435 | 387 | ||
| 436 | streams.append(s) | 388 | streams.append(s) |
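Putting the offline changes together: the recognizer now receives the hotwords options at construction time instead of per-stream context lists. Below is a minimal sketch of the resulting Python usage, assuming the keyword names shown in the diff above; the model and hotwords file paths are placeholders, not files shipped with this PR:

    import numpy as np
    import sherpa_onnx

    # Placeholder paths; substitute the files of any offline transducer model.
    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        tokens="tokens.txt",
        num_threads=1,
        decoding_method="modified_beam_search",  # hotwords require modified_beam_search
        hotwords_file="hotwords.txt",            # one tokenized word/phrase per line
        hotwords_score=1.5,                      # per-token bonus for biased phrases
    )

    samples = np.zeros(16000, dtype=np.float32)  # stand-in for real audio in [-1, 1]
    stream = recognizer.create_stream()
    stream.accept_waveform(16000, samples)
    recognizer.decode_streams([stream])
    print(stream.result.text)

Per-stream hotwords remain available through the C++ CreateStream(hotwords) overload shown further down; the Python example scripts now simply rely on the recognizer-level hotwords file.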
| @@ -48,7 +48,6 @@ from pathlib import Path | @@ -48,7 +48,6 @@ from pathlib import Path | ||
| 48 | from typing import List, Tuple | 48 | from typing import List, Tuple |
| 49 | 49 | ||
| 50 | import numpy as np | 50 | import numpy as np |
| 51 | -import sentencepiece as spm | ||
| 52 | import sherpa_onnx | 51 | import sherpa_onnx |
| 53 | 52 | ||
| 54 | 53 | ||
| @@ -124,46 +123,25 @@ def get_args(): | @@ -124,46 +123,25 @@ def get_args(): | ||
| 124 | ) | 123 | ) |
| 125 | 124 | ||
| 126 | parser.add_argument( | 125 | parser.add_argument( |
| 127 | - "--bpe-model", | 126 | + "--hotwords-file", |
| 128 | type=str, | 127 | type=str, |
| 129 | default="", | 128 | default="", |
| 130 | help=""" | 129 | help=""" |
| 131 | - Path to bpe.model, it will be used to tokenize contexts biasing phrases. | ||
| 132 | - Used only when --decoding-method=modified_beam_search | ||
| 133 | - """, | ||
| 134 | - ) | ||
| 135 | - | ||
| 136 | - parser.add_argument( | ||
| 137 | - "--modeling-unit", | ||
| 138 | - type=str, | ||
| 139 | - default="char", | ||
| 140 | - help=""" | ||
| 141 | - The type of modeling unit, it will be used to tokenize contexts biasing phrases. | ||
| 142 | - Valid values are bpe, bpe+char, char. | ||
| 143 | - Note: the char here means characters in CJK languages. | ||
| 144 | - Used only when --decoding-method=modified_beam_search | ||
| 145 | - """, | ||
| 146 | - ) | 130 | + The file containing hotwords, one words/phrases per line, and for each |
| 131 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 147 | 132 | ||
| 148 | - parser.add_argument( | ||
| 149 | - "--contexts", | ||
| 150 | - type=str, | ||
| 151 | - default="", | ||
| 152 | - help=""" | ||
| 153 | - The context list, it is a string containing some words/phrases separated | ||
| 154 | - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". | ||
| 155 | - Used only when --decoding-method=modified_beam_search | 133 | + ▁HE LL O ▁WORLD |
| 134 | + 你 好 世 界 | ||
| 156 | """, | 135 | """, |
| 157 | ) | 136 | ) |
| 158 | 137 | ||
| 159 | parser.add_argument( | 138 | parser.add_argument( |
| 160 | - "--context-score", | 139 | + "--hotwords-score", |
| 161 | type=float, | 140 | type=float, |
| 162 | default=1.5, | 141 | default=1.5, |
| 163 | help=""" | 142 | help=""" |
| 164 | - The context score of each token for biasing word/phrase. Used only if | ||
| 165 | - --contexts is given. | ||
| 166 | - Used only when --decoding-method=modified_beam_search | 143 | + The hotword score of each token for biasing word/phrase. Used only if |
| 144 | + --hotwords-file is given. | ||
| 167 | """, | 145 | """, |
| 168 | ) | 146 | ) |
| 169 | 147 | ||
| @@ -214,27 +192,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | @@ -214,27 +192,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 214 | return samples_float32, f.getframerate() | 192 | return samples_float32, f.getframerate() |
| 215 | 193 | ||
| 216 | 194 | ||
| 217 | -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | ||
| 218 | - sp = None | ||
| 219 | - if "bpe" in args.modeling_unit: | ||
| 220 | - assert_file_exists(args.bpe_model) | ||
| 221 | - sp = spm.SentencePieceProcessor() | ||
| 222 | - sp.load(args.bpe_model) | ||
| 223 | - tokens = {} | ||
| 224 | - with open(args.tokens, "r", encoding="utf-8") as f: | ||
| 225 | - for line in f: | ||
| 226 | - toks = line.strip().split() | ||
| 227 | - assert len(toks) == 2, len(toks) | ||
| 228 | - assert toks[0] not in tokens, f"Duplicate token: {toks} " | ||
| 229 | - tokens[toks[0]] = int(toks[1]) | ||
| 230 | - return sherpa_onnx.encode_contexts( | ||
| 231 | - modeling_unit=args.modeling_unit, | ||
| 232 | - contexts=contexts, | ||
| 233 | - sp=sp, | ||
| 234 | - tokens_table=tokens, | ||
| 235 | - ) | ||
| 236 | - | ||
| 237 | - | ||
| 238 | def main(): | 195 | def main(): |
| 239 | args = get_args() | 196 | args = get_args() |
| 240 | assert_file_exists(args.tokens) | 197 | assert_file_exists(args.tokens) |
| @@ -258,7 +215,8 @@ def main(): | @@ -258,7 +215,8 @@ def main(): | ||
| 258 | feature_dim=80, | 215 | feature_dim=80, |
| 259 | decoding_method=args.decoding_method, | 216 | decoding_method=args.decoding_method, |
| 260 | max_active_paths=args.max_active_paths, | 217 | max_active_paths=args.max_active_paths, |
| 261 | - context_score=args.context_score, | 218 | + hotwords_file=args.hotwords_file, |
| 219 | + hotwords_score=args.hotwords_score, | ||
| 262 | ) | 220 | ) |
| 263 | elif args.paraformer_encoder: | 221 | elif args.paraformer_encoder: |
| 264 | recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( | 222 | recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( |
| @@ -277,12 +235,6 @@ def main(): | @@ -277,12 +235,6 @@ def main(): | ||
| 277 | print("Started!") | 235 | print("Started!") |
| 278 | start_time = time.time() | 236 | start_time = time.time() |
| 279 | 237 | ||
| 280 | - contexts_list = [] | ||
| 281 | - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] | ||
| 282 | - if contexts: | ||
| 283 | - print(f"Contexts list: {contexts}") | ||
| 284 | - contexts_list = encode_contexts(args, contexts) | ||
| 285 | - | ||
| 286 | streams = [] | 238 | streams = [] |
| 287 | total_duration = 0 | 239 | total_duration = 0 |
| 288 | for wave_filename in args.sound_files: | 240 | for wave_filename in args.sound_files: |
| @@ -291,10 +243,7 @@ def main(): | @@ -291,10 +243,7 @@ def main(): | ||
| 291 | duration = len(samples) / sample_rate | 243 | duration = len(samples) / sample_rate |
| 292 | total_duration += duration | 244 | total_duration += duration |
| 293 | 245 | ||
| 294 | - if contexts_list: | ||
| 295 | - s = recognizer.create_stream(contexts_list=contexts_list) | ||
| 296 | - else: | ||
| 297 | - s = recognizer.create_stream() | 246 | + s = recognizer.create_stream() |
| 298 | 247 | ||
| 299 | s.accept_waveform(sample_rate, samples) | 248 | s.accept_waveform(sample_rate, samples) |
| 300 | 249 |
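The streaming API mirrors the offline one: hotwords are attached to the recognizer itself, so every stream created with create_stream() inherits the hotwords graph built from the file. A short sketch using the keyword names from the diff above (model paths are placeholders):

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens="tokens.txt",     # placeholder paths for a streaming transducer model
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        num_threads=1,
        decoding_method="modified_beam_search",
        max_active_paths=4,
        hotwords_file="hotwords.txt",
        hotwords_score=1.5,
    )

    # No contexts_list argument anymore; the file-level hotwords apply here.
    stream = recognizer.create_stream()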
| @@ -79,6 +79,30 @@ def get_args(): | @@ -79,6 +79,30 @@ def get_args(): | ||
| 79 | help="Valid values: cpu, cuda, coreml", | 79 | help="Valid values: cpu, cuda, coreml", |
| 80 | ) | 80 | ) |
| 81 | 81 | ||
| 82 | + parser.add_argument( | ||
| 83 | + "--hotwords-file", | ||
| 84 | + type=str, | ||
| 85 | + default="", | ||
| 86 | + help=""" | ||
| 87 | + The file containing hotwords, one words/phrases per line, and for each | ||
| 88 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 89 | + | ||
| 90 | + ▁HE LL O ▁WORLD | ||
| 91 | + 你 好 世 界 | ||
| 92 | + """, | ||
| 93 | + ) | ||
| 94 | + | ||
| 95 | + parser.add_argument( | ||
| 96 | + "--hotwords-score", | ||
| 97 | + type=float, | ||
| 98 | + default=1.5, | ||
| 99 | + help=""" | ||
| 100 | + The hotword score of each token for biasing word/phrase. Used only if | ||
| 101 | + --hotwords-file is given. | ||
| 102 | + """, | ||
| 103 | + ) | ||
| 104 | + | ||
| 105 | + | ||
| 82 | return parser.parse_args() | 106 | return parser.parse_args() |
| 83 | 107 | ||
| 84 | 108 | ||
| @@ -104,6 +128,8 @@ def create_recognizer(args): | @@ -104,6 +128,8 @@ def create_recognizer(args): | ||
| 104 | rule3_min_utterance_length=300, # it essentially disables this rule | 128 | rule3_min_utterance_length=300, # it essentially disables this rule |
| 105 | decoding_method=args.decoding_method, | 129 | decoding_method=args.decoding_method, |
| 106 | provider=args.provider, | 130 | provider=args.provider, |
| 131 | + hotwords_file=args.hotwords_file, | ||
| 132 | + hotwords_score=args.hotwords_score, | ||
| 107 | ) | 133 | ) |
| 108 | return recognizer | 134 | return recognizer |
| 109 | 135 |
| @@ -11,7 +11,6 @@ import sys | @@ -11,7 +11,6 @@ import sys | ||
| 11 | from pathlib import Path | 11 | from pathlib import Path |
| 12 | 12 | ||
| 13 | from typing import List | 13 | from typing import List |
| 14 | -import sentencepiece as spm | ||
| 15 | 14 | ||
| 16 | try: | 15 | try: |
| 17 | import sounddevice as sd | 16 | import sounddevice as sd |
| @@ -90,49 +89,29 @@ def get_args(): | @@ -90,49 +89,29 @@ def get_args(): | ||
| 90 | ) | 89 | ) |
| 91 | 90 | ||
| 92 | parser.add_argument( | 91 | parser.add_argument( |
| 93 | - "--bpe-model", | 92 | + "--hotwords-file", |
| 94 | type=str, | 93 | type=str, |
| 95 | default="", | 94 | default="", |
| 96 | help=""" | 95 | help=""" |
| 97 | - Path to bpe.model, it will be used to tokenize contexts biasing phrases. | ||
| 98 | - Used only when --decoding-method=modified_beam_search | ||
| 99 | - """, | ||
| 100 | - ) | 96 | + The file containing hotwords, one words/phrases per line, and for each |
| 97 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 101 | 98 | ||
| 102 | - parser.add_argument( | ||
| 103 | - "--modeling-unit", | ||
| 104 | - type=str, | ||
| 105 | - default="char", | ||
| 106 | - help=""" | ||
| 107 | - The type of modeling unit, it will be used to tokenize contexts biasing phrases. | ||
| 108 | - Valid values are bpe, bpe+char, char. | ||
| 109 | - Note: the char here means characters in CJK languages. | ||
| 110 | - Used only when --decoding-method=modified_beam_search | 99 | + ▁HE LL O ▁WORLD |
| 100 | + 你 好 世 界 | ||
| 111 | """, | 101 | """, |
| 112 | ) | 102 | ) |
| 113 | 103 | ||
| 114 | parser.add_argument( | 104 | parser.add_argument( |
| 115 | - "--contexts", | ||
| 116 | - type=str, | ||
| 117 | - default="", | ||
| 118 | - help=""" | ||
| 119 | - The context list, it is a string containing some words/phrases separated | ||
| 120 | - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". | ||
| 121 | - Used only when --decoding-method=modified_beam_search | ||
| 122 | - """, | ||
| 123 | - ) | ||
| 124 | - | ||
| 125 | - parser.add_argument( | ||
| 126 | - "--context-score", | 105 | + "--hotwords-score", |
| 127 | type=float, | 106 | type=float, |
| 128 | default=1.5, | 107 | default=1.5, |
| 129 | help=""" | 108 | help=""" |
| 130 | - The context score of each token for biasing word/phrase. Used only if | ||
| 131 | - --contexts is given. | ||
| 132 | - Used only when --decoding-method=modified_beam_search | 109 | + The hotword score of each token for biasing word/phrase. Used only if |
| 110 | + --hotwords-file is given. | ||
| 133 | """, | 111 | """, |
| 134 | ) | 112 | ) |
| 135 | 113 | ||
| 114 | + | ||
| 136 | return parser.parse_args() | 115 | return parser.parse_args() |
| 137 | 116 | ||
| 138 | 117 | ||
| @@ -155,32 +134,12 @@ def create_recognizer(args): | @@ -155,32 +134,12 @@ def create_recognizer(args): | ||
| 155 | decoding_method=args.decoding_method, | 134 | decoding_method=args.decoding_method, |
| 156 | max_active_paths=args.max_active_paths, | 135 | max_active_paths=args.max_active_paths, |
| 157 | provider=args.provider, | 136 | provider=args.provider, |
| 158 | - context_score=args.context_score, | 137 | + hotwords_file=args.hotwords_file, |
| 138 | + hotwords_score=args.hotwords_score, | ||
| 159 | ) | 139 | ) |
| 160 | return recognizer | 140 | return recognizer |
| 161 | 141 | ||
| 162 | 142 | ||
| 163 | -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | ||
| 164 | - sp = None | ||
| 165 | - if "bpe" in args.modeling_unit: | ||
| 166 | - assert_file_exists(args.bpe_model) | ||
| 167 | - sp = spm.SentencePieceProcessor() | ||
| 168 | - sp.load(args.bpe_model) | ||
| 169 | - tokens = {} | ||
| 170 | - with open(args.tokens, "r", encoding="utf-8") as f: | ||
| 171 | - for line in f: | ||
| 172 | - toks = line.strip().split() | ||
| 173 | - assert len(toks) == 2, len(toks) | ||
| 174 | - assert toks[0] not in tokens, f"Duplicate token: {toks} " | ||
| 175 | - tokens[toks[0]] = int(toks[1]) | ||
| 176 | - return sherpa_onnx.encode_contexts( | ||
| 177 | - modeling_unit=args.modeling_unit, | ||
| 178 | - contexts=contexts, | ||
| 179 | - sp=sp, | ||
| 180 | - tokens_table=tokens, | ||
| 181 | - ) | ||
| 182 | - | ||
| 183 | - | ||
| 184 | def main(): | 143 | def main(): |
| 185 | args = get_args() | 144 | args = get_args() |
| 186 | 145 | ||
| @@ -193,12 +152,6 @@ def main(): | @@ -193,12 +152,6 @@ def main(): | ||
| 193 | default_input_device_idx = sd.default.device[0] | 152 | default_input_device_idx = sd.default.device[0] |
| 194 | print(f'Use default device: {devices[default_input_device_idx]["name"]}') | 153 | print(f'Use default device: {devices[default_input_device_idx]["name"]}') |
| 195 | 154 | ||
| 196 | - contexts_list = [] | ||
| 197 | - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] | ||
| 198 | - if contexts: | ||
| 199 | - print(f"Contexts list: {contexts}") | ||
| 200 | - contexts_list = encode_contexts(args, contexts) | ||
| 201 | - | ||
| 202 | recognizer = create_recognizer(args) | 155 | recognizer = create_recognizer(args) |
| 203 | print("Started! Please speak") | 156 | print("Started! Please speak") |
| 204 | 157 | ||
| @@ -207,10 +160,7 @@ def main(): | @@ -207,10 +160,7 @@ def main(): | ||
| 207 | sample_rate = 48000 | 160 | sample_rate = 48000 |
| 208 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | 161 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms |
| 209 | last_result = "" | 162 | last_result = "" |
| 210 | - if contexts_list: | ||
| 211 | - stream = recognizer.create_stream(contexts_list=contexts_list) | ||
| 212 | - else: | ||
| 213 | - stream = recognizer.create_stream() | 163 | + stream = recognizer.create_stream() |
| 214 | with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: | 164 | with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: |
| 215 | while True: | 165 | while True: |
| 216 | samples, _ = s.read(samples_per_read) # a blocking read | 166 | samples, _ = s.read(samples_per_read) # a blocking read |
| @@ -87,6 +87,30 @@ def get_args(): | @@ -87,6 +87,30 @@ def get_args(): | ||
| 87 | """, | 87 | """, |
| 88 | ) | 88 | ) |
| 89 | 89 | ||
| 90 | + parser.add_argument( | ||
| 91 | + "--hotwords-file", | ||
| 92 | + type=str, | ||
| 93 | + default="", | ||
| 94 | + help=""" | ||
| 95 | + The file containing hotwords, one words/phrases per line, and for each | ||
| 96 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 97 | + | ||
| 98 | + ▁HE LL O ▁WORLD | ||
| 99 | + 你 好 世 界 | ||
| 100 | + """, | ||
| 101 | + ) | ||
| 102 | + | ||
| 103 | + parser.add_argument( | ||
| 104 | + "--hotwords-score", | ||
| 105 | + type=float, | ||
| 106 | + default=1.5, | ||
| 107 | + help=""" | ||
| 108 | + The hotword score of each token for biasing word/phrase. Used only if | ||
| 109 | + --hotwords-file is given. | ||
| 110 | + """, | ||
| 111 | + ) | ||
| 112 | + | ||
| 113 | + | ||
| 90 | return parser.parse_args() | 114 | return parser.parse_args() |
| 91 | 115 | ||
| 92 | 116 | ||
| @@ -107,6 +131,8 @@ def create_recognizer(args): | @@ -107,6 +131,8 @@ def create_recognizer(args): | ||
| 107 | rule1_min_trailing_silence=2.4, | 131 | rule1_min_trailing_silence=2.4, |
| 108 | rule2_min_trailing_silence=1.2, | 132 | rule2_min_trailing_silence=1.2, |
| 109 | rule3_min_utterance_length=300, # it essentially disables this rule | 133 | rule3_min_utterance_length=300, # it essentially disables this rule |
| 134 | + hotwords_file=args.hotwords_file, | ||
| 135 | + hotwords_score=args.hotwords_score, | ||
| 110 | ) | 136 | ) |
| 111 | return recognizer | 137 | return recognizer |
| 112 | 138 |
| @@ -187,6 +187,32 @@ def add_decoding_args(parser: argparse.ArgumentParser): | @@ -187,6 +187,32 @@ def add_decoding_args(parser: argparse.ArgumentParser): | ||
| 187 | add_modified_beam_search_args(parser) | 187 | add_modified_beam_search_args(parser) |
| 188 | 188 | ||
| 189 | 189 | ||
| 190 | +def add_hotwords_args(parser: argparse.ArgumentParser): | ||
| 191 | + parser.add_argument( | ||
| 192 | + "--hotwords-file", | ||
| 193 | + type=str, | ||
| 194 | + default="", | ||
| 195 | + help=""" | ||
| 196 | + The file containing hotwords, one words/phrases per line, and for each | ||
| 197 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 198 | + | ||
| 199 | + ▁HE LL O ▁WORLD | ||
| 200 | + 你 好 世 界 | ||
| 201 | + """, | ||
| 202 | + ) | ||
| 203 | + | ||
| 204 | + parser.add_argument( | ||
| 205 | + "--hotwords-score", | ||
| 206 | + type=float, | ||
| 207 | + default=1.5, | ||
| 208 | + help=""" | ||
| 209 | + The hotword score of each token for biasing word/phrase. Used only if | ||
| 210 | + --hotwords-file is given. | ||
| 211 | + """, | ||
| 212 | + ) | ||
| 213 | + | ||
| 214 | + | ||
| 215 | + | ||
| 190 | def add_modified_beam_search_args(parser: argparse.ArgumentParser): | 216 | def add_modified_beam_search_args(parser: argparse.ArgumentParser): |
| 191 | parser.add_argument( | 217 | parser.add_argument( |
| 192 | "--num-active-paths", | 218 | "--num-active-paths", |
| @@ -239,6 +265,7 @@ def get_args(): | @@ -239,6 +265,7 @@ def get_args(): | ||
| 239 | add_model_args(parser) | 265 | add_model_args(parser) |
| 240 | add_decoding_args(parser) | 266 | add_decoding_args(parser) |
| 241 | add_endpointing_args(parser) | 267 | add_endpointing_args(parser) |
| 268 | + add_hotwords_args(parser) | ||
| 242 | 269 | ||
| 243 | parser.add_argument( | 270 | parser.add_argument( |
| 244 | "--port", | 271 | "--port", |
| @@ -343,6 +370,8 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | @@ -343,6 +370,8 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | ||
| 343 | feature_dim=args.feat_dim, | 370 | feature_dim=args.feat_dim, |
| 344 | decoding_method=args.decoding_method, | 371 | decoding_method=args.decoding_method, |
| 345 | max_active_paths=args.num_active_paths, | 372 | max_active_paths=args.num_active_paths, |
| 373 | + hotwords_score=args.hotwords_score, | ||
| 374 | + hotwords_file=args.hotwords_file, | ||
| 346 | enable_endpoint_detection=args.use_endpoint != 0, | 375 | enable_endpoint_detection=args.use_endpoint != 0, |
| 347 | rule1_min_trailing_silence=args.rule1_min_trailing_silence, | 376 | rule1_min_trailing_silence=args.rule1_min_trailing_silence, |
| 348 | rule2_min_trailing_silence=args.rule2_min_trailing_silence, | 377 | rule2_min_trailing_silence=args.rule2_min_trailing_silence, |
scripts/text2token.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This script encodes the texts (given line by line through `text`) to tokens and | ||
| 5 | +write the results to the file given by ``output``. | ||
| 6 | + | ||
| 7 | +Usage: | ||
| 8 | +If the tokens_type is bpe: | ||
| 9 | + | ||
| 10 | +python3 ./text2token.py \ | ||
| 11 | + --text texts.txt \ | ||
| 12 | + --tokens tokens.txt \ | ||
| 13 | + --tokens-type bpe \ | ||
| 14 | + --bpe-model bpe.model \ | ||
| 15 | + --output hotwords.txt | ||
| 16 | + | ||
| 17 | +If the tokens_type is cjkchar: | ||
| 18 | + | ||
| 19 | +python3 ./text2token.py \ | ||
| 20 | + --text texts.txt \ | ||
| 21 | + --tokens tokens.txt \ | ||
| 22 | + --tokens-type cjkchar \ | ||
| 23 | + --output hotwords.txt | ||
| 24 | + | ||
| 25 | +If the tokens_type is cjkchar+bpe: | ||
| 26 | + | ||
| 27 | +python3 ./text2token.py \ | ||
| 28 | + --text texts.txt \ | ||
| 29 | + --tokens tokens.txt \ | ||
| 30 | + --tokens-type cjkchar+bpe \ | ||
| 31 | + --bpe-model bpe.model \ | ||
| 32 | + --output hotwords.txt | ||
| 33 | + | ||
| 34 | +""" | ||
| 35 | +import argparse | ||
| 36 | + | ||
| 37 | +from sherpa_onnx import text2token | ||
| 38 | + | ||
| 39 | +def get_args(): | ||
| 40 | + parser = argparse.ArgumentParser() | ||
| 41 | + parser.add_argument( | ||
| 42 | + "--text", | ||
| 43 | + type=str, | ||
| 44 | + required=True, | ||
| 45 | + help="Path to the input texts", | ||
| 46 | + ) | ||
| 47 | + | ||
| 48 | + parser.add_argument( | ||
| 49 | + "--tokens", | ||
| 50 | + type=str, | ||
| 51 | + required=True, | ||
| 52 | + help="The path to tokens.txt.", | ||
| 53 | + ) | ||
| 54 | + | ||
| 55 | + parser.add_argument( | ||
| 56 | + "--tokens-type", | ||
| 57 | + type=str, | ||
| 58 | + required=True, | ||
| 59 | + help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", | ||
| 60 | + ) | ||
| 61 | + | ||
| 62 | + parser.add_argument( | ||
| 63 | + "--bpe-model", | ||
| 64 | + type=str, | ||
| 65 | + help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.", | ||
| 66 | + ) | ||
| 67 | + | ||
| 68 | + parser.add_argument( | ||
| 69 | + "--output", | ||
| 70 | + type=str, | ||
| 71 | + required=True, | ||
| 72 | + help="Path where the encoded tokens will be written to.", | ||
| 73 | + ) | ||
| 74 | + | ||
| 75 | + return parser.parse_args() | ||
| 76 | + | ||
| 77 | + | ||
| 78 | +def main(): | ||
| 79 | + args = get_args() | ||
| 80 | + | ||
| 81 | + texts = [] | ||
| 82 | + with open(args.text, "r", encoding="utf8") as f: | ||
| 83 | + for line in f: | ||
| 84 | + texts.append(line.strip()) | ||
| 85 | + encoded_texts = text2token( | ||
| 86 | + texts, | ||
| 87 | + tokens=args.tokens, | ||
| 88 | + tokens_type=args.tokens_type, | ||
| 89 | + bpe_model=args.bpe_model, | ||
| 90 | + ) | ||
| 91 | + with open(args.output, "w", encoding="utf8") as f: | ||
| 92 | + for txt in encoded_texts: | ||
| 93 | + f.write(" ".join(txt) + "\n") | ||
| 94 | + | ||
| 95 | + | ||
| 96 | +if __name__ == "__main__": | ||
| 97 | + main() |
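Because text2token is exposed from the sherpa_onnx package, the same conversion can be done programmatically. A minimal sketch that mirrors the script above, assuming tokens.txt and bpe.model come from the model you intend to decode with (the file names are placeholders):

    from sherpa_onnx import text2token

    # Plain-text hotwords, one word/phrase per line.
    texts = ["HELLO WORLD", "你好世界"]

    encoded = text2token(
        texts,
        tokens="tokens.txt",
        tokens_type="cjkchar+bpe",
        bpe_model="bpe.model",
    )

    with open("hotwords.txt", "w", encoding="utf8") as f:
        for toks in encoded:
            f.write(" ".join(toks) + "\n")

    # Expected output lines look like "▁HE LL O ▁WORLD" and "你 好 世 界",
    # matching the --hotwords-file format used throughout this PR.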
| @@ -39,6 +39,7 @@ install_requires = [ | @@ -39,6 +39,7 @@ install_requires = [ | ||
| 39 | "numpy", | 39 | "numpy", |
| 40 | "sentencepiece==0.1.96; python_version < '3.11'", | 40 | "sentencepiece==0.1.96; python_version < '3.11'", |
| 41 | "sentencepiece; python_version >= '3.11'", | 41 | "sentencepiece; python_version >= '3.11'", |
| 42 | + "click>=7.1.1", | ||
| 42 | ] | 43 | ] |
| 43 | 44 | ||
| 44 | 45 | ||
| @@ -93,6 +94,11 @@ setuptools.setup( | @@ -93,6 +94,11 @@ setuptools.setup( | ||
| 93 | "Programming Language :: Python", | 94 | "Programming Language :: Python", |
| 94 | "Topic :: Scientific/Engineering :: Artificial Intelligence", | 95 | "Topic :: Scientific/Engineering :: Artificial Intelligence", |
| 95 | ], | 96 | ], |
| 97 | + entry_points={ | ||
| 98 | + 'console_scripts': [ | ||
| 99 | + 'sherpa-onnx-cli=sherpa_onnx.cli:cli', | ||
| 100 | + ], | ||
| 101 | + }, | ||
| 96 | license="Apache licensed, as found in the LICENSE file", | 102 | license="Apache licensed, as found in the LICENSE file", |
| 97 | ) | 103 | ) |
| 98 | 104 |
| @@ -4,11 +4,14 @@ | @@ -4,11 +4,14 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/context-graph.h" | 5 | #include "sherpa-onnx/csrc/context-graph.h" |
| 6 | 6 | ||
| 7 | +#include <chrono> // NOLINT | ||
| 7 | #include <map> | 8 | #include <map> |
| 9 | +#include <random> | ||
| 8 | #include <string> | 10 | #include <string> |
| 9 | #include <vector> | 11 | #include <vector> |
| 10 | 12 | ||
| 11 | #include "gtest/gtest.h" | 13 | #include "gtest/gtest.h" |
| 14 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | 15 | ||
| 13 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 14 | 17 | ||
| @@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) { | @@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) { | ||
| 41 | } | 44 | } |
| 42 | } | 45 | } |
| 43 | 46 | ||
| 47 | +TEST(ContextGraph, Benchmark) { | ||
| 48 | + std::random_device rd; | ||
| 49 | + std::mt19937 mt(rd()); | ||
| 50 | + std::uniform_int_distribution<int32_t> char_dist(0, 25); | ||
| 51 | + std::uniform_int_distribution<int32_t> len_dist(3, 8); | ||
| 52 | + for (int32_t num = 10; num <= 10000; num *= 10) { | ||
| 53 | + std::vector<std::vector<int32_t>> contexts; | ||
| 54 | + for (int32_t i = 0; i < num; ++i) { | ||
| 55 | + std::vector<int32_t> tmp; | ||
| 56 | + int32_t word_len = len_dist(mt); | ||
| 57 | + for (int32_t j = 0; j < word_len; ++j) { | ||
| 58 | + tmp.push_back(char_dist(mt)); | ||
| 59 | + } | ||
| 60 | + contexts.push_back(std::move(tmp)); | ||
| 61 | + } | ||
| 62 | + auto start = std::chrono::high_resolution_clock::now(); | ||
| 63 | + auto context_graph = ContextGraph(contexts, 1); | ||
| 64 | + auto stop = std::chrono::high_resolution_clock::now(); | ||
| 65 | + auto duration = | ||
| 66 | + std::chrono::duration_cast<std::chrono::microseconds>(stop - start); | ||
| 67 | + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num, | ||
| 68 | + duration.count()); | ||
| 69 | + } | ||
| 70 | +} | ||
| 71 | + | ||
| 44 | } // namespace sherpa_onnx | 72 | } // namespace sherpa_onnx |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <string> | ||
| 9 | #include <vector> | 10 | #include <vector> |
| 10 | 11 | ||
| 11 | #if __ANDROID_API__ >= 9 | 12 | #if __ANDROID_API__ >= 9 |
| @@ -32,7 +33,7 @@ class OfflineRecognizerImpl { | @@ -32,7 +33,7 @@ class OfflineRecognizerImpl { | ||
| 32 | virtual ~OfflineRecognizerImpl() = default; | 33 | virtual ~OfflineRecognizerImpl() = default; |
| 33 | 34 | ||
| 34 | virtual std::unique_ptr<OfflineStream> CreateStream( | 35 | virtual std::unique_ptr<OfflineStream> CreateStream( |
| 35 | - const std::vector<std::vector<int32_t>> &context_list) const { | 36 | + const std::string &hotwords) const { |
| 36 | SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); | 37 | SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); |
| 37 | exit(-1); | 38 | exit(-1); |
| 38 | } | 39 | } |
| @@ -5,7 +5,9 @@ | @@ -5,7 +5,9 @@ | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ |
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ |
| 7 | 7 | ||
| 8 | +#include <fstream> | ||
| 8 | #include <memory> | 9 | #include <memory> |
| 10 | +#include <regex> // NOLINT | ||
| 9 | #include <string> | 11 | #include <string> |
| 10 | #include <utility> | 12 | #include <utility> |
| 11 | #include <vector> | 13 | #include <vector> |
| @@ -16,6 +18,7 @@ | @@ -16,6 +18,7 @@ | ||
| 16 | #endif | 18 | #endif |
| 17 | 19 | ||
| 18 | #include "sherpa-onnx/csrc/context-graph.h" | 20 | #include "sherpa-onnx/csrc/context-graph.h" |
| 21 | +#include "sherpa-onnx/csrc/log.h" | ||
| 19 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 20 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" | 23 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" |
| 21 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 24 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| @@ -25,6 +28,7 @@ | @@ -25,6 +28,7 @@ | ||
| 25 | #include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" | 28 | #include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" |
| 26 | #include "sherpa-onnx/csrc/pad-sequence.h" | 29 | #include "sherpa-onnx/csrc/pad-sequence.h" |
| 27 | #include "sherpa-onnx/csrc/symbol-table.h" | 30 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 31 | +#include "sherpa-onnx/csrc/utils.h" | ||
| 28 | 32 | ||
| 29 | namespace sherpa_onnx { | 33 | namespace sherpa_onnx { |
| 30 | 34 | ||
| @@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 60 | : config_(config), | 64 | : config_(config), |
| 61 | symbol_table_(config_.model_config.tokens), | 65 | symbol_table_(config_.model_config.tokens), |
| 62 | model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { | 66 | model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { |
| 67 | + if (!config_.hotwords_file.empty()) { | ||
| 68 | + InitHotwords(); | ||
| 69 | + } | ||
| 63 | if (config_.decoding_method == "greedy_search") { | 70 | if (config_.decoding_method == "greedy_search") { |
| 64 | decoder_ = | 71 | decoder_ = |
| 65 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | 72 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); |
| @@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 105 | #endif | 112 | #endif |
| 106 | 113 | ||
| 107 | std::unique_ptr<OfflineStream> CreateStream( | 114 | std::unique_ptr<OfflineStream> CreateStream( |
| 108 | - const std::vector<std::vector<int32_t>> &context_list) const override { | ||
| 109 | - // We create context_graph at this level, because we might have default | ||
| 110 | - // context_graph(will be added later if needed) that belongs to the whole | ||
| 111 | - // model rather than each stream. | 115 | + const std::string &hotwords) const override { |
| 116 | + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); | ||
| 117 | + std::istringstream is(hws); | ||
| 118 | + std::vector<std::vector<int32_t>> current; | ||
| 119 | + if (!EncodeHotwords(is, symbol_table_, ¤t)) { | ||
| 120 | + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", | ||
| 121 | + hotwords.c_str()); | ||
| 122 | + } | ||
| 123 | + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); | ||
| 124 | + | ||
| 112 | auto context_graph = | 125 | auto context_graph = |
| 113 | - std::make_shared<ContextGraph>(context_list, config_.context_score); | 126 | + std::make_shared<ContextGraph>(current, config_.hotwords_score); |
| 114 | return std::make_unique<OfflineStream>(config_.feat_config, context_graph); | 127 | return std::make_unique<OfflineStream>(config_.feat_config, context_graph); |
| 115 | } | 128 | } |
| 116 | 129 | ||
| 117 | std::unique_ptr<OfflineStream> CreateStream() const override { | 130 | std::unique_ptr<OfflineStream> CreateStream() const override { |
| 118 | - return std::make_unique<OfflineStream>(config_.feat_config); | 131 | + return std::make_unique<OfflineStream>(config_.feat_config, |
| 132 | + hotwords_graph_); | ||
| 119 | } | 133 | } |
| 120 | 134 | ||
| 121 | void DecodeStreams(OfflineStream **ss, int32_t n) const override { | 135 | void DecodeStreams(OfflineStream **ss, int32_t n) const override { |
| @@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 171 | } | 185 | } |
| 172 | } | 186 | } |
| 173 | 187 | ||
| 188 | + void InitHotwords() { | ||
| 189 | + // each line in hotwords_file contains space-separated words | ||
| 190 | + | ||
| 191 | + std::ifstream is(config_.hotwords_file); | ||
| 192 | + if (!is) { | ||
| 193 | + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", | ||
| 194 | + config_.hotwords_file.c_str()); | ||
| 195 | + exit(-1); | ||
| 196 | + } | ||
| 197 | + | ||
| 198 | + if (!EncodeHotwords(is, symbol_table_, &hotwords_)) { | ||
| 199 | + SHERPA_ONNX_LOGE("Encode hotwords failed."); | ||
| 200 | + exit(-1); | ||
| 201 | + } | ||
| 202 | + hotwords_graph_ = | ||
| 203 | + std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | ||
| 204 | + } | ||
| 205 | + | ||
| 174 | private: | 206 | private: |
| 175 | OfflineRecognizerConfig config_; | 207 | OfflineRecognizerConfig config_; |
| 176 | SymbolTable symbol_table_; | 208 | SymbolTable symbol_table_; |
| 209 | + std::vector<std::vector<int32_t>> hotwords_; | ||
| 210 | + ContextGraphPtr hotwords_graph_; | ||
| 177 | std::unique_ptr<OfflineTransducerModel> model_; | 211 | std::unique_ptr<OfflineTransducerModel> model_; |
| 178 | std::unique_ptr<OfflineTransducerDecoder> decoder_; | 212 | std::unique_ptr<OfflineTransducerDecoder> decoder_; |
| 179 | std::unique_ptr<OfflineLM> lm_; | 213 | std::unique_ptr<OfflineLM> lm_; |
| @@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | @@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 26 | 26 | ||
| 27 | po->Register("max-active-paths", &max_active_paths, | 27 | po->Register("max-active-paths", &max_active_paths, |
| 28 | "Used only when decoding_method is modified_beam_search"); | 28 | "Used only when decoding_method is modified_beam_search"); |
| 29 | - po->Register("context-score", &context_score, | 29 | + |
| 30 | + po->Register( | ||
| 31 | + "hotwords-file", &hotwords_file, | ||
| 32 | + "The file containing hotwords, one words/phrases per line, and for each" | ||
| 33 | + "phrase the bpe/cjkchar are separated by a space. For example: " | ||
| 34 | + "▁HE LL O ▁WORLD" | ||
| 35 | + "你 好 世 界"); | ||
| 36 | + | ||
| 37 | + po->Register("hotwords-score", &hotwords_score, | ||
| 30 | "The bonus score for each token in context word/phrase. " | 38 | "The bonus score for each token in context word/phrase. " |
| 31 | "Used only when decoding_method is modified_beam_search"); | 39 | "Used only when decoding_method is modified_beam_search"); |
| 32 | } | 40 | } |
| @@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 53 | os << "lm_config=" << lm_config.ToString() << ", "; | 61 | os << "lm_config=" << lm_config.ToString() << ", "; |
| 54 | os << "decoding_method=\"" << decoding_method << "\", "; | 62 | os << "decoding_method=\"" << decoding_method << "\", "; |
| 55 | os << "max_active_paths=" << max_active_paths << ", "; | 63 | os << "max_active_paths=" << max_active_paths << ", "; |
| 56 | - os << "context_score=" << context_score << ")"; | 64 | + os << "hotwords_file=\"" << hotwords_file << "\", "; |
| 65 | + os << "hotwords_score=" << hotwords_score << ")"; | ||
| 57 | 66 | ||
| 58 | return os.str(); | 67 | return os.str(); |
| 59 | } | 68 | } |
| @@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) | @@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) | ||
| 70 | OfflineRecognizer::~OfflineRecognizer() = default; | 79 | OfflineRecognizer::~OfflineRecognizer() = default; |
| 71 | 80 | ||
| 72 | std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream( | 81 | std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream( |
| 73 | - const std::vector<std::vector<int32_t>> &context_list) const { | ||
| 74 | - return impl_->CreateStream(context_list); | 82 | + const std::string &hotwords) const { |
| 83 | + return impl_->CreateStream(hotwords); | ||
| 75 | } | 84 | } |
| 76 | 85 | ||
| 77 | std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { | 86 | std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { |
| @@ -31,7 +31,10 @@ struct OfflineRecognizerConfig { | @@ -31,7 +31,10 @@ struct OfflineRecognizerConfig { | ||
| 31 | 31 | ||
| 32 | std::string decoding_method = "greedy_search"; | 32 | std::string decoding_method = "greedy_search"; |
| 33 | int32_t max_active_paths = 4; | 33 | int32_t max_active_paths = 4; |
| 34 | - float context_score = 1.5; | 34 | + |
| 35 | + std::string hotwords_file; | ||
| 36 | + float hotwords_score = 1.5; | ||
| 37 | + | ||
| 35 | // only greedy_search is implemented | 38 | // only greedy_search is implemented |
| 36 | // TODO(fangjun): Implement modified_beam_search | 39 | // TODO(fangjun): Implement modified_beam_search |
| 37 | 40 | ||
| @@ -40,13 +43,16 @@ struct OfflineRecognizerConfig { | @@ -40,13 +43,16 @@ struct OfflineRecognizerConfig { | ||
| 40 | const OfflineModelConfig &model_config, | 43 | const OfflineModelConfig &model_config, |
| 41 | const OfflineLMConfig &lm_config, | 44 | const OfflineLMConfig &lm_config, |
| 42 | const std::string &decoding_method, | 45 | const std::string &decoding_method, |
| 43 | - int32_t max_active_paths, float context_score) | 46 | + int32_t max_active_paths, |
| 47 | + const std::string &hotwords_file, | ||
| 48 | + float hotwords_score) | ||
| 44 | : feat_config(feat_config), | 49 | : feat_config(feat_config), |
| 45 | model_config(model_config), | 50 | model_config(model_config), |
| 46 | lm_config(lm_config), | 51 | lm_config(lm_config), |
| 47 | decoding_method(decoding_method), | 52 | decoding_method(decoding_method), |
| 48 | max_active_paths(max_active_paths), | 53 | max_active_paths(max_active_paths), |
| 49 | - context_score(context_score) {} | 54 | + hotwords_file(hotwords_file), |
| 55 | + hotwords_score(hotwords_score) {} | ||
| 50 | 56 | ||
| 51 | void Register(ParseOptions *po); | 57 | void Register(ParseOptions *po); |
| 52 | bool Validate() const; | 58 | bool Validate() const; |
| @@ -69,9 +75,17 @@ class OfflineRecognizer { | @@ -69,9 +75,17 @@ class OfflineRecognizer { | ||
| 69 | /// Create a stream for decoding. | 75 | /// Create a stream for decoding. |
| 70 | std::unique_ptr<OfflineStream> CreateStream() const; | 76 | std::unique_ptr<OfflineStream> CreateStream() const; |
| 71 | 77 | ||
| 72 | - /// Create a stream for decoding. | 78 | + /** Create a stream for decoding. |
| 79 | + * | ||
| 80 | + * @param hotwords The hotwords for this stream, it might contain several hotwords, | ||
| 81 | + * the hotwords are separated by "/". In each of the hotwords, there | ||
| 82 | + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). | ||
| 83 | + * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: | ||
| 84 | + * | ||
| 85 | + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" | ||
| 86 | + */ | ||
| 73 | std::unique_ptr<OfflineStream> CreateStream( | 87 | std::unique_ptr<OfflineStream> CreateStream( |
| 74 | - const std::vector<std::vector<int32_t>> &context_list) const; | 88 | + const std::string &hotwords) const; |
| 75 | 89 | ||
| 76 | /** Decode a single stream | 90 | /** Decode a single stream |
| 77 | * | 91 | * |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <string> | ||
| 9 | #include <vector> | 10 | #include <vector> |
| 10 | 11 | ||
| 11 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| @@ -29,7 +30,7 @@ class OnlineRecognizerImpl { | @@ -29,7 +30,7 @@ class OnlineRecognizerImpl { | ||
| 29 | virtual std::unique_ptr<OnlineStream> CreateStream() const = 0; | 30 | virtual std::unique_ptr<OnlineStream> CreateStream() const = 0; |
| 30 | 31 | ||
| 31 | virtual std::unique_ptr<OnlineStream> CreateStream( | 32 | virtual std::unique_ptr<OnlineStream> CreateStream( |
| 32 | - const std::vector<std::vector<int32_t>> &contexts) const { | 33 | + const std::string &hotwords) const { |
| 33 | SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); | 34 | SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); |
| 34 | exit(-1); | 35 | exit(-1); |
| 35 | } | 36 | } |
| @@ -7,6 +7,8 @@ | @@ -7,6 +7,8 @@ | ||
| 7 | 7 | ||
| 8 | #include <algorithm> | 8 | #include <algorithm> |
| 9 | #include <memory> | 9 | #include <memory> |
| 10 | +#include <regex> // NOLINT | ||
| 11 | +#include <string> | ||
| 10 | #include <utility> | 12 | #include <utility> |
| 11 | #include <vector> | 13 | #include <vector> |
| 12 | 14 | ||
| @@ -20,6 +22,7 @@ | @@ -20,6 +22,7 @@ | ||
| 20 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 22 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| 21 | #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" | 23 | #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" |
| 22 | #include "sherpa-onnx/csrc/symbol-table.h" | 24 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 25 | +#include "sherpa-onnx/csrc/utils.h" | ||
| 23 | 26 | ||
| 24 | namespace sherpa_onnx { | 27 | namespace sherpa_onnx { |
| 25 | 28 | ||
| @@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 57 | model_(OnlineTransducerModel::Create(config.model_config)), | 60 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 58 | sym_(config.model_config.tokens), | 61 | sym_(config.model_config.tokens), |
| 59 | endpoint_(config_.endpoint_config) { | 62 | endpoint_(config_.endpoint_config) { |
| 63 | + if (!config_.hotwords_file.empty()) { | ||
| 64 | + InitHotwords(); | ||
| 65 | + } | ||
| 60 | if (sym_.contains("<unk>")) { | 66 | if (sym_.contains("<unk>")) { |
| 61 | unk_id_ = sym_["<unk>"]; | 67 | unk_id_ = sym_["<unk>"]; |
| 62 | } | 68 | } |
| @@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 106 | #endif | 112 | #endif |
| 107 | 113 | ||
| 108 | std::unique_ptr<OnlineStream> CreateStream() const override { | 114 | std::unique_ptr<OnlineStream> CreateStream() const override { |
| 109 | - auto stream = std::make_unique<OnlineStream>(config_.feat_config); | 115 | + auto stream = |
| 116 | + std::make_unique<OnlineStream>(config_.feat_config, hotwords_graph_); | ||
| 110 | InitOnlineStream(stream.get()); | 117 | InitOnlineStream(stream.get()); |
| 111 | return stream; | 118 | return stream; |
| 112 | } | 119 | } |
| 113 | 120 | ||
| 114 | std::unique_ptr<OnlineStream> CreateStream( | 121 | std::unique_ptr<OnlineStream> CreateStream( |
| 115 | - const std::vector<std::vector<int32_t>> &contexts) const override { | ||
| 116 | - // We create context_graph at this level, because we might have default | ||
| 117 | - // context_graph(will be added later if needed) that belongs to the whole | ||
| 118 | - // model rather than each stream. | 122 | + const std::string &hotwords) const override { |
| 123 | + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); | ||
| 124 | + std::istringstream is(hws); | ||
| 125 | + std::vector<std::vector<int32_t>> current; | ||
| 126 | + if (!EncodeHotwords(is, sym_, ¤t)) { | ||
| 127 | + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", | ||
| 128 | + hotwords.c_str()); | ||
| 129 | + } | ||
| 130 | + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); | ||
| 119 | auto context_graph = | 131 | auto context_graph = |
| 120 | - std::make_shared<ContextGraph>(contexts, config_.context_score); | 132 | + std::make_shared<ContextGraph>(current, config_.hotwords_score); |
| 121 | auto stream = | 133 | auto stream = |
| 122 | std::make_unique<OnlineStream>(config_.feat_config, context_graph); | 134 | std::make_unique<OnlineStream>(config_.feat_config, context_graph); |
| 123 | InitOnlineStream(stream.get()); | 135 | InitOnlineStream(stream.get()); |
| @@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 253 | s->Reset(); | 265 | s->Reset(); |
| 254 | } | 266 | } |
| 255 | 267 | ||
| 268 | + void InitHotwords() { | ||
| 269 | + // each line in hotwords_file contains space-separated words | ||
| 270 | + | ||
| 271 | + std::ifstream is(config_.hotwords_file); | ||
| 272 | + if (!is) { | ||
| 273 | + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", | ||
| 274 | + config_.hotwords_file.c_str()); | ||
| 275 | + exit(-1); | ||
| 276 | + } | ||
| 277 | + | ||
| 278 | + if (!EncodeHotwords(is, sym_, &hotwords_)) { | ||
| 279 | + SHERPA_ONNX_LOGE("Encode hotwords failed."); | ||
| 280 | + exit(-1); | ||
| 281 | + } | ||
| 282 | + hotwords_graph_ = | ||
| 283 | + std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | ||
| 284 | + } | ||
| 285 | + | ||
| 256 | private: | 286 | private: |
| 257 | void InitOnlineStream(OnlineStream *stream) const { | 287 | void InitOnlineStream(OnlineStream *stream) const { |
| 258 | auto r = decoder_->GetEmptyResult(); | 288 | auto r = decoder_->GetEmptyResult(); |
| @@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 271 | 301 | ||
| 272 | private: | 302 | private: |
| 273 | OnlineRecognizerConfig config_; | 303 | OnlineRecognizerConfig config_; |
| 304 | + std::vector<std::vector<int32_t>> hotwords_; | ||
| 305 | + ContextGraphPtr hotwords_graph_; | ||
| 274 | std::unique_ptr<OnlineTransducerModel> model_; | 306 | std::unique_ptr<OnlineTransducerModel> model_; |
| 275 | std::unique_ptr<OnlineLM> lm_; | 307 | std::unique_ptr<OnlineLM> lm_; |
| 276 | std::unique_ptr<OnlineTransducerDecoder> decoder_; | 308 | std::unique_ptr<OnlineTransducerDecoder> decoder_; |
| @@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 57 | "True to enable endpoint detection. False to disable it."); | 57 | "True to enable endpoint detection. False to disable it."); |
| 58 | po->Register("max-active-paths", &max_active_paths, | 58 | po->Register("max-active-paths", &max_active_paths, |
| 59 | "beam size used in modified beam search."); | 59 | "beam size used in modified beam search."); |
| 60 | - po->Register("context-score", &context_score, | 60 | + po->Register("hotwords-score", &hotwords_score, |
| 61 | "The bonus score for each token in context word/phrase. " | 61 | "The bonus score for each token in context word/phrase. " |
| 62 | "Used only when decoding_method is modified_beam_search"); | 62 | "Used only when decoding_method is modified_beam_search"); |
| 63 | + po->Register( | ||
| 64 | + "hotwords-file", &hotwords_file, | ||
| 65 | + "The file containing hotwords, one words/phrases per line, and for each" | ||
| 66 | + "phrase the bpe/cjkchar are separated by a space. For example: " | ||
| 67 | + "▁HE LL O ▁WORLD" | ||
| 68 | + "你 好 世 界"); | ||
| 63 | po->Register("decoding-method", &decoding_method, | 69 | po->Register("decoding-method", &decoding_method, |
| 64 | "decoding method," | 70 | "decoding method," |
| 65 | "now support greedy_search and modified_beam_search."); | 71 | "now support greedy_search and modified_beam_search."); |
| @@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 87 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; | 93 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; |
| 88 | os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; | 94 | os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; |
| 89 | os << "max_active_paths=" << max_active_paths << ", "; | 95 | os << "max_active_paths=" << max_active_paths << ", "; |
| 90 | - os << "context_score=" << context_score << ", "; | 96 | + os << "hotwords_score=" << hotwords_score << ", "; |
| 97 | + os << "hotwords_file=\"" << hotwords_file << "\", "; | ||
| 91 | os << "decoding_method=\"" << decoding_method << "\")"; | 98 | os << "decoding_method=\"" << decoding_method << "\")"; |
| 92 | 99 | ||
| 93 | return os.str(); | 100 | return os.str(); |
| @@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { | @@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { | ||
| 109 | } | 116 | } |
| 110 | 117 | ||
| 111 | std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream( | 118 | std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream( |
| 112 | - const std::vector<std::vector<int32_t>> &context_list) const { | ||
| 113 | - return impl_->CreateStream(context_list); | 119 | + const std::string &hotwords) const { |
| 120 | + return impl_->CreateStream(hotwords); | ||
| 114 | } | 121 | } |
| 115 | 122 | ||
| 116 | bool OnlineRecognizer::IsReady(OnlineStream *s) const { | 123 | bool OnlineRecognizer::IsReady(OnlineStream *s) const { |
| @@ -78,8 +78,10 @@ struct OnlineRecognizerConfig { | @@ -78,8 +78,10 @@ struct OnlineRecognizerConfig { | ||
| 78 | 78 | ||
| 79 | // used only for modified_beam_search | 79 | // used only for modified_beam_search |
| 80 | int32_t max_active_paths = 4; | 80 | int32_t max_active_paths = 4; |
| 81 | + | ||
| 81 | /// used only for modified_beam_search | 82 | /// used only for modified_beam_search |
| 82 | - float context_score = 1.5; | 83 | + float hotwords_score = 1.5; |
| 84 | + std::string hotwords_file; | ||
| 83 | 85 | ||
| 84 | OnlineRecognizerConfig() = default; | 86 | OnlineRecognizerConfig() = default; |
| 85 | 87 | ||
| @@ -89,14 +91,16 @@ struct OnlineRecognizerConfig { | @@ -89,14 +91,16 @@ struct OnlineRecognizerConfig { | ||
| 89 | const EndpointConfig &endpoint_config, | 91 | const EndpointConfig &endpoint_config, |
| 90 | bool enable_endpoint, | 92 | bool enable_endpoint, |
| 91 | const std::string &decoding_method, | 93 | const std::string &decoding_method, |
| 92 | - int32_t max_active_paths, float context_score) | 94 | + int32_t max_active_paths, |
| 95 | + const std::string &hotwords_file, float hotwords_score) | ||
| 93 | : feat_config(feat_config), | 96 | : feat_config(feat_config), |
| 94 | model_config(model_config), | 97 | model_config(model_config), |
| 95 | endpoint_config(endpoint_config), | 98 | endpoint_config(endpoint_config), |
| 96 | enable_endpoint(enable_endpoint), | 99 | enable_endpoint(enable_endpoint), |
| 97 | decoding_method(decoding_method), | 100 | decoding_method(decoding_method), |
| 98 | max_active_paths(max_active_paths), | 101 | max_active_paths(max_active_paths), |
| 99 | - context_score(context_score) {} | 102 | + hotwords_score(hotwords_score), |
| 103 | + hotwords_file(hotwords_file) {} | ||
| 100 | 104 | ||
| 101 | void Register(ParseOptions *po); | 105 | void Register(ParseOptions *po); |
| 102 | bool Validate() const; | 106 | bool Validate() const; |
| @@ -119,9 +123,16 @@ class OnlineRecognizer { | @@ -119,9 +123,16 @@ class OnlineRecognizer { | ||
| 119 | /// Create a stream for decoding. | 123 | /// Create a stream for decoding. |
| 120 | std::unique_ptr<OnlineStream> CreateStream() const; | 124 | std::unique_ptr<OnlineStream> CreateStream() const; |
| 121 | 125 | ||
| 122 | - // Create a stream with context phrases | ||
| 123 | - std::unique_ptr<OnlineStream> CreateStream( | ||
| 124 | - const std::vector<std::vector<int32_t>> &context_list) const; | 126 | + /** Create a stream for decoding. |
| 127 | + * | ||
| 128 | + * @param hotwords The hotwords for this stream. It may contain several | ||
| 129 | + * hotwords separated by "/". Within each hotword, the cjkchars | ||
| 130 | + * or bpe pieces are separated by spaces (" "). | ||
| 131 | + * For example, the hotwords I LOVE YOU and HELLO WORLD look like: | ||
| 132 | + * | ||
| 133 | + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" | ||
| 134 | + */ | ||
| 135 | + std::unique_ptr<OnlineStream> CreateStream(const std::string &hotwords) const; | ||
| 125 | 136 | ||
| 126 | /** | 137 | /** |
| 127 | * Return true if the given stream has enough frames for decoding. | 138 | * Return true if the given stream has enough frames for decoding. |
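The same per-stream hotwords string is exposed through the Python binding added later in this commit. A minimal sketch (it assumes `recognizer` is an already-constructed sherpa_onnx.OnlineRecognizer using a transducer model with decoding_method="modified_beam_search"; model loading is omitted):

# Several hotwords are joined by "/"; tokens inside one hotword are separated
# by spaces (BPE pieces for English, single characters for CJK).
def make_stream_with_hotwords(recognizer):
    hotwords = "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
    return recognizer.create_stream(hotwords)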
sherpa-onnx/csrc/utils.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/utils.cc | ||
| 2 | +// | ||
| 3 | +// Copyright 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/utils.h" | ||
| 6 | + | ||
| 7 | +#include <iostream> | ||
| 8 | +#include <sstream> | ||
| 9 | +#include <string> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/log.h" | ||
| 14 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | ||
| 19 | + std::vector<std::vector<int32_t>> *hotwords) { | ||
| 20 | + hotwords->clear(); | ||
| 21 | + std::vector<int32_t> tmp; | ||
| 22 | + std::string line; | ||
| 23 | + std::string word; | ||
| 24 | + | ||
| 25 | + while (std::getline(is, line)) { | ||
| 26 | + std::istringstream iss(line); | ||
| 27 | + std::vector<std::string> syms; | ||
| 28 | + while (iss >> word) { | ||
| 29 | + if (word.size() >= 3) { | ||
| 30 | + // For BPE-based models, we replace ▁ with a space | ||
| 31 | + // Unicode 9601, hex 0x2581, utf8 0xe29681 | ||
| 32 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str()); | ||
| 33 | + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 34 | + word = word.replace(0, 3, " "); | ||
| 35 | + } | ||
| 36 | + } | ||
| 37 | + if (symbol_table.contains(word)) { | ||
| 38 | + int32_t number = symbol_table[word]; | ||
| 39 | + tmp.push_back(number); | ||
| 40 | + } else { | ||
| 41 | + SHERPA_ONNX_LOGE( | ||
| 42 | + "Cannot find ID for hotword %s at line: %s. (Hint: words on " | ||
| 43 | + "the " | ||
| 44 | + "same line are separated by spaces)", | ||
| 45 | + word.c_str(), line.c_str()); | ||
| 46 | + return false; | ||
| 47 | + } | ||
| 48 | + } | ||
| 49 | + hotwords->push_back(std::move(tmp)); | ||
| 50 | + } | ||
| 51 | + return true; | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/utils.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/utils.h | ||
| 2 | +// | ||
| 3 | +// Copyright 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_UTILS_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_UTILS_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +/* Encode the hotwords from an input stream into token ids. | ||
| 15 | + * | ||
| 16 | + * @param is The input stream; it contains several lines, one hotword per | ||
| 17 | + * line. For each hotword, the tokens (cjkchar or bpe) are separated | ||
| 18 | + * by spaces. | ||
| 19 | + * @param symbol_table The tokens table mapping symbols to ids. All the symbols | ||
| 20 | + * in the stream should be in the symbol_table; if not, this | ||
| 21 | + * function returns false. | ||
| 22 | + * | ||
| 23 | + * @param hotwords The encoded ids to be written to. | ||
| 24 | + * | ||
| 25 | + * @return If all the symbols from ``is`` are in the symbol_table, returns true; | ||
| 26 | + * otherwise returns false. | ||
| 27 | + */ | ||
| 28 | +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | ||
| 29 | + std::vector<std::vector<int32_t>> *hotwords); | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_UTILS_H_ |
| @@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") | 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") |
| 17 | .def(py::init<const OfflineFeatureExtractorConfig &, | 17 | .def(py::init<const OfflineFeatureExtractorConfig &, |
| 18 | const OfflineModelConfig &, const OfflineLMConfig &, | 18 | const OfflineModelConfig &, const OfflineLMConfig &, |
| 19 | - const std::string &, int32_t, float>(), | 19 | + const std::string &, int32_t, const std::string &, float>(), |
| 20 | py::arg("feat_config"), py::arg("model_config"), | 20 | py::arg("feat_config"), py::arg("model_config"), |
| 21 | py::arg("lm_config") = OfflineLMConfig(), | 21 | py::arg("lm_config") = OfflineLMConfig(), |
| 22 | py::arg("decoding_method") = "greedy_search", | 22 | py::arg("decoding_method") = "greedy_search", |
| 23 | - py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5) | 23 | + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 24 | + py::arg("hotwords_score") = 1.5) | ||
| 24 | .def_readwrite("feat_config", &PyClass::feat_config) | 25 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 25 | .def_readwrite("model_config", &PyClass::model_config) | 26 | .def_readwrite("model_config", &PyClass::model_config) |
| 26 | .def_readwrite("lm_config", &PyClass::lm_config) | 27 | .def_readwrite("lm_config", &PyClass::lm_config) |
| 27 | .def_readwrite("decoding_method", &PyClass::decoding_method) | 28 | .def_readwrite("decoding_method", &PyClass::decoding_method) |
| 28 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 29 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| 29 | - .def_readwrite("context_score", &PyClass::context_score) | 30 | + .def_readwrite("hotwords_file", &PyClass::hotwords_file) |
| 31 | + .def_readwrite("hotwords_score", &PyClass::hotwords_score) | ||
| 30 | .def("__str__", &PyClass::ToString); | 32 | .def("__str__", &PyClass::ToString); |
| 31 | } | 33 | } |
| 32 | 34 | ||
| @@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) { | @@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) { | ||
| 40 | [](const PyClass &self) { return self.CreateStream(); }) | 42 | [](const PyClass &self) { return self.CreateStream(); }) |
| 41 | .def( | 43 | .def( |
| 42 | "create_stream", | 44 | "create_stream", |
| 43 | - [](PyClass &self, | ||
| 44 | - const std::vector<std::vector<int32_t>> &contexts_list) { | ||
| 45 | - return self.CreateStream(contexts_list); | 45 | + [](PyClass &self, const std::string &hotwords) { |
| 46 | + return self.CreateStream(hotwords); | ||
| 46 | }, | 47 | }, |
| 47 | - py::arg("contexts_list")) | 48 | + py::arg("hotwords")) |
| 48 | .def("decode_stream", &PyClass::DecodeStream) | 49 | .def("decode_stream", &PyClass::DecodeStream) |
| 49 | .def("decode_streams", | 50 | .def("decode_streams", |
| 50 | [](const PyClass &self, std::vector<OfflineStream *> ss) { | 51 | [](const PyClass &self, std::vector<OfflineStream *> ss) { |
| @@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) { | @@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) { | ||
| 21 | using PyClass = OnlineModelConfig; | 21 | using PyClass = OnlineModelConfig; |
| 22 | py::class_<PyClass>(*m, "OnlineModelConfig") | 22 | py::class_<PyClass>(*m, "OnlineModelConfig") |
| 23 | .def(py::init<const OnlineTransducerModelConfig &, | 23 | .def(py::init<const OnlineTransducerModelConfig &, |
| 24 | - const OnlineParaformerModelConfig &, std::string &, int32_t, | ||
| 25 | - bool, const std::string &, const std::string &>(), | 24 | + const OnlineParaformerModelConfig &, const std::string &, |
| 25 | + int32_t, bool, const std::string &, const std::string &>(), | ||
| 26 | py::arg("transducer") = OnlineTransducerModelConfig(), | 26 | py::arg("transducer") = OnlineTransducerModelConfig(), |
| 27 | py::arg("paraformer") = OnlineParaformerModelConfig(), | 27 | py::arg("paraformer") = OnlineParaformerModelConfig(), |
| 28 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 28 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| @@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 29 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 29 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 30 | .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, | 30 | .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, |
| 31 | const OnlineLMConfig &, const EndpointConfig &, bool, | 31 | const OnlineLMConfig &, const EndpointConfig &, bool, |
| 32 | - const std::string &, int32_t, float>(), | 32 | + const std::string &, int32_t, const std::string &, float>(), |
| 33 | py::arg("feat_config"), py::arg("model_config"), | 33 | py::arg("feat_config"), py::arg("model_config"), |
| 34 | py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), | 34 | py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), |
| 35 | py::arg("enable_endpoint"), py::arg("decoding_method"), | 35 | py::arg("enable_endpoint"), py::arg("decoding_method"), |
| 36 | - py::arg("max_active_paths") = 4, py::arg("context_score") = 0) | 36 | + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 37 | + py::arg("hotwords_score") = 0) | ||
| 37 | .def_readwrite("feat_config", &PyClass::feat_config) | 38 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 38 | .def_readwrite("model_config", &PyClass::model_config) | 39 | .def_readwrite("model_config", &PyClass::model_config) |
| 39 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) | 40 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) |
| 40 | .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) | 41 | .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) |
| 41 | .def_readwrite("decoding_method", &PyClass::decoding_method) | 42 | .def_readwrite("decoding_method", &PyClass::decoding_method) |
| 42 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 43 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| 43 | - .def_readwrite("context_score", &PyClass::context_score) | 44 | + .def_readwrite("hotwords_file", &PyClass::hotwords_file) |
| 45 | + .def_readwrite("hotwords_score", &PyClass::hotwords_score) | ||
| 44 | .def("__str__", &PyClass::ToString); | 46 | .def("__str__", &PyClass::ToString); |
| 45 | } | 47 | } |
| 46 | 48 | ||
| @@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) { | @@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) { | ||
| 55 | [](const PyClass &self) { return self.CreateStream(); }) | 57 | [](const PyClass &self) { return self.CreateStream(); }) |
| 56 | .def( | 58 | .def( |
| 57 | "create_stream", | 59 | "create_stream", |
| 58 | - [](PyClass &self, | ||
| 59 | - const std::vector<std::vector<int32_t>> &contexts_list) { | ||
| 60 | - return self.CreateStream(contexts_list); | 60 | + [](PyClass &self, const std::string &hotwords) { |
| 61 | + return self.CreateStream(hotwords); | ||
| 61 | }, | 62 | }, |
| 62 | - py::arg("contexts_list")) | 63 | + py::arg("hotwords")) |
| 63 | .def("is_ready", &PyClass::IsReady) | 64 | .def("is_ready", &PyClass::IsReady) |
| 64 | .def("decode_stream", &PyClass::DecodeStream) | 65 | .def("decode_stream", &PyClass::DecodeStream) |
| 65 | .def("decode_streams", | 66 | .def("decode_streams", |
| @@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream | @@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream | ||
| 4 | 4 | ||
| 5 | from .offline_recognizer import OfflineRecognizer | 5 | from .offline_recognizer import OfflineRecognizer |
| 6 | from .online_recognizer import OnlineRecognizer | 6 | from .online_recognizer import OnlineRecognizer |
| 7 | -from .utils import encode_contexts | 7 | +from .utils import text2token |
sherpa-onnx/python/sherpa_onnx/cli.py
0 → 100644
| 1 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +import logging | ||
| 4 | +import click | ||
| 5 | +from pathlib import Path | ||
| 6 | +from sherpa_onnx import text2token | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +@click.group() | ||
| 10 | +def cli(): | ||
| 11 | + """ | ||
| 12 | + The shell entry point to sherpa-onnx. | ||
| 13 | + """ | ||
| 14 | + logging.basicConfig( | ||
| 15 | + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", | ||
| 16 | + level=logging.INFO, | ||
| 17 | + ) | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +@cli.command(name="text2token") | ||
| 21 | +@click.argument("input", type=click.Path(exists=True, dir_okay=False)) | ||
| 22 | +@click.argument("output", type=click.Path()) | ||
| 23 | +@click.option( | ||
| 24 | + "--tokens", | ||
| 25 | + type=str, | ||
| 26 | + required=True, | ||
| 27 | + help="The path to tokens.txt.", | ||
| 28 | +) | ||
| 29 | +@click.option( | ||
| 30 | + "--tokens-type", | ||
| 31 | + type=str, | ||
| 32 | + required=True, | ||
| 33 | + help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", | ||
| 34 | +) | ||
| 35 | +@click.option( | ||
| 36 | + "--bpe-model", | ||
| 37 | + type=str, | ||
| 38 | + help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.", | ||
| 39 | +) | ||
| 40 | +def encode_text( | ||
| 41 | + input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path | ||
| 42 | +): | ||
| 43 | + """ | ||
| 44 | + Encode the texts given by the INPUT to tokens and write the results to the OUTPUT. | ||
| 45 | + """ | ||
| 46 | + texts = [] | ||
| 47 | + with open(input, "r", encoding="utf8") as f: | ||
| 48 | + for line in f: | ||
| 49 | + texts.append(line.strip()) | ||
| 50 | + encoded_texts = text2token( | ||
| 51 | + texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model | ||
| 52 | + ) | ||
| 53 | + with open(output, "w", encoding="utf8") as f: | ||
| 54 | + for txt in encoded_texts: | ||
| 55 | + f.write(" ".join(txt) + "\n") |
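The new subcommand can be exercised in-process with click's test runner. The sketch below is only an illustration: it presumes the sherpa_onnx package, click, and sentencepiece are installed, and it creates a tiny cjkchar tokens.txt and input file on the fly (all file names are made up).

import tempfile
from pathlib import Path

from click.testing import CliRunner
from sherpa_onnx.cli import cli

tmp = Path(tempfile.mkdtemp())
# tokens.txt: one "symbol id" pair per line
(tmp / "tokens.txt").write_text("你 0\n好 1\n世 2\n界 3\n", encoding="utf-8")
# input: one phrase per line
(tmp / "phrases.txt").write_text("你好\n世界\n", encoding="utf-8")

result = CliRunner().invoke(
    cli,
    [
        "text2token",
        str(tmp / "phrases.txt"),
        str(tmp / "hotwords.txt"),
        "--tokens", str(tmp / "tokens.txt"),
        "--tokens-type", "cjkchar",
    ],
)
assert result.exit_code == 0, result.output
print((tmp / "hotwords.txt").read_text(encoding="utf-8"))
# Expected output:
# 你 好
# 世 界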
| @@ -43,7 +43,8 @@ class OfflineRecognizer(object): | @@ -43,7 +43,8 @@ class OfflineRecognizer(object): | ||
| 43 | feature_dim: int = 80, | 43 | feature_dim: int = 80, |
| 44 | decoding_method: str = "greedy_search", | 44 | decoding_method: str = "greedy_search", |
| 45 | max_active_paths: int = 4, | 45 | max_active_paths: int = 4, |
| 46 | - context_score: float = 1.5, | 46 | + hotwords_file: str = "", |
| 47 | + hotwords_score: float = 1.5, | ||
| 47 | debug: bool = False, | 48 | debug: bool = False, |
| 48 | provider: str = "cpu", | 49 | provider: str = "cpu", |
| 49 | ): | 50 | ): |
| @@ -105,7 +106,8 @@ class OfflineRecognizer(object): | @@ -105,7 +106,8 @@ class OfflineRecognizer(object): | ||
| 105 | feat_config=feat_config, | 106 | feat_config=feat_config, |
| 106 | model_config=model_config, | 107 | model_config=model_config, |
| 107 | decoding_method=decoding_method, | 108 | decoding_method=decoding_method, |
| 108 | - context_score=context_score, | 109 | + hotwords_file=hotwords_file, |
| 110 | + hotwords_score=hotwords_score, | ||
| 109 | ) | 111 | ) |
| 110 | self.recognizer = _Recognizer(recognizer_config) | 112 | self.recognizer = _Recognizer(recognizer_config) |
| 111 | self.config = recognizer_config | 113 | self.config = recognizer_config |
| @@ -379,11 +381,11 @@ class OfflineRecognizer(object): | @@ -379,11 +381,11 @@ class OfflineRecognizer(object): | ||
| 379 | self.config = recognizer_config | 381 | self.config = recognizer_config |
| 380 | return self | 382 | return self |
| 381 | 383 | ||
| 382 | - def create_stream(self, contexts_list: Optional[List[List[int]]] = None): | ||
| 383 | - if contexts_list is None: | 384 | + def create_stream(self, hotwords: Optional[str] = None): |
| 385 | + if hotwords is None: | ||
| 384 | return self.recognizer.create_stream() | 386 | return self.recognizer.create_stream() |
| 385 | else: | 387 | else: |
| 386 | - return self.recognizer.create_stream(contexts_list) | 388 | + return self.recognizer.create_stream(hotwords) |
| 387 | 389 | ||
| 388 | def decode_stream(self, s: OfflineStream): | 390 | def decode_stream(self, s: OfflineStream): |
| 389 | self.recognizer.decode_stream(s) | 391 | self.recognizer.decode_stream(s) |
| @@ -42,7 +42,8 @@ class OnlineRecognizer(object): | @@ -42,7 +42,8 @@ class OnlineRecognizer(object): | ||
| 42 | rule3_min_utterance_length: float = 20.0, | 42 | rule3_min_utterance_length: float = 20.0, |
| 43 | decoding_method: str = "greedy_search", | 43 | decoding_method: str = "greedy_search", |
| 44 | max_active_paths: int = 4, | 44 | max_active_paths: int = 4, |
| 45 | - context_score: float = 1.5, | 45 | + hotwords_score: float = 1.5, |
| 46 | + hotwords_file: str = "", | ||
| 46 | provider: str = "cpu", | 47 | provider: str = "cpu", |
| 47 | model_type: str = "", | 48 | model_type: str = "", |
| 48 | ): | 49 | ): |
| @@ -138,7 +139,8 @@ class OnlineRecognizer(object): | @@ -138,7 +139,8 @@ class OnlineRecognizer(object): | ||
| 138 | enable_endpoint=enable_endpoint_detection, | 139 | enable_endpoint=enable_endpoint_detection, |
| 139 | decoding_method=decoding_method, | 140 | decoding_method=decoding_method, |
| 140 | max_active_paths=max_active_paths, | 141 | max_active_paths=max_active_paths, |
| 141 | - context_score=context_score, | 142 | + hotwords_score=hotwords_score, |
| 143 | + hotwords_file=hotwords_file, | ||
| 142 | ) | 144 | ) |
| 143 | 145 | ||
| 144 | self.recognizer = _Recognizer(recognizer_config) | 146 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -248,11 +250,11 @@ class OnlineRecognizer(object): | @@ -248,11 +250,11 @@ class OnlineRecognizer(object): | ||
| 248 | self.config = recognizer_config | 250 | self.config = recognizer_config |
| 249 | return self | 251 | return self |
| 250 | 252 | ||
| 251 | - def create_stream(self, contexts_list: Optional[List[List[int]]] = None): | ||
| 252 | - if contexts_list is None: | 253 | + def create_stream(self, hotwords: Optional[str] = None): |
| 254 | + if hotwords is None: | ||
| 253 | return self.recognizer.create_stream() | 255 | return self.recognizer.create_stream() |
| 254 | else: | 256 | else: |
| 255 | - return self.recognizer.create_stream(contexts_list) | 257 | + return self.recognizer.create_stream(hotwords) |
| 256 | 258 | ||
| 257 | def decode_stream(self, s: OnlineStream): | 259 | def decode_stream(self, s: OnlineStream): |
| 258 | self.recognizer.decode_stream(s) | 260 | self.recognizer.decode_stream(s) |
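Putting the pieces together, a recognizer-wide hotwords file plus extra per-stream hotwords could be wired up roughly as below. This is only a sketch: the model/token paths are placeholders, and the keyword names other than decoding_method, hotwords_file, and hotwords_score are assumed to match this version of the Python API.

import sherpa_onnx

# hotwords.txt holds one tokenized phrase per line, e.g. written out with
# sherpa_onnx.text2token (placeholder paths; adjust to your model files).
recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    decoding_method="modified_beam_search",  # hotword biasing needs this method
    hotwords_file="hotwords.txt",
    hotwords_score=1.5,
)

# Stream-specific hotwords are combined with the ones loaded from hotwords_file.
stream = recognizer.create_stream("▁HE LL O ▁WORLD/你 好 世 界")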
| 1 | -from typing import Dict, List, Optional | 1 | +# Copyright (c) 2023 Xiaomi Corporation |
| 2 | +import re | ||
| 2 | 3 | ||
| 4 | +from pathlib import Path | ||
| 5 | +from typing import List, Optional, Union | ||
| 3 | 6 | ||
| 4 | -def encode_contexts( | ||
| 5 | - modeling_unit: str, | ||
| 6 | - contexts: List[str], | ||
| 7 | - sp: Optional["SentencePieceProcessor"] = None, | ||
| 8 | - tokens_table: Optional[Dict[str, int]] = None, | ||
| 9 | -) -> List[List[int]]: | 7 | +import sentencepiece as spm |
| 8 | + | ||
| 9 | + | ||
| 10 | +def text2token( | ||
| 11 | + texts: List[str], | ||
| 12 | + tokens: str, | ||
| 13 | + tokens_type: str = "cjkchar", | ||
| 14 | + bpe_model: Optional[str] = None, | ||
| 15 | + output_ids: bool = False, | ||
| 16 | +) -> List[List[Union[str, int]]]: | ||
| 10 | """ | 17 | """ |
| 11 | - Encode the given contexts (a list of string) to a list of a list of token ids. | 18 | + Encode the given texts (a list of strings) into a list of lists of tokens. |
| 12 | 19 | ||
| 13 | Args: | 20 | Args: |
| 14 | - modeling_unit: | ||
| 15 | - The valid values are bpe, char, bpe+char. | ||
| 16 | - Note: char here means characters in CJK languages, not English like languages. | ||
| 17 | - contexts: | 21 | + texts: |
| 18 | The given contexts list (a list of string). | 22 | The given contexts list (a list of string). |
| 19 | - sp: | ||
| 20 | - An instance of SentencePieceProcessor. | ||
| 21 | - tokens_table: | ||
| 22 | - The tokens_table containing the tokens and the corresponding ids. | 23 | + tokens: |
| 24 | + The path of the tokens.txt. | ||
| 25 | + tokens_type: | ||
| 26 | + The valid values are cjkchar, bpe, cjkchar+bpe. | ||
| 27 | + bpe_model: | ||
| 28 | + The path of the bpe model. Only required when tokens_type is bpe or | ||
| 29 | + cjkchar+bpe. | ||
| 30 | + output_ids: | ||
| 31 | + True to output token ids otherwise tokens. | ||
| 23 | Returns: | 32 | Returns: |
| 24 | - Return the contexts_list, it is a list of a list of token ids. | 33 | + Return the encoded texts: a list of lists of token ids if output_ids |
| 34 | + is True, otherwise a list of lists of tokens. | ||
| 25 | """ | 35 | """ |
| 26 | - contexts_list = [] | ||
| 27 | - if "bpe" in modeling_unit: | ||
| 28 | - assert sp is not None | ||
| 29 | - if "char" in modeling_unit: | ||
| 30 | - assert tokens_table is not None | ||
| 31 | - assert len(tokens_table) > 0, len(tokens_table) | 36 | + assert Path(tokens).is_file(), f"File does not exist: {tokens}" |
| 37 | + tokens_table = {} | ||
| 38 | + with open(tokens, "r", encoding="utf-8") as f: | ||
| 39 | + for line in f: | ||
| 40 | + toks = line.strip().split() | ||
| 41 | + assert len(toks) == 2, len(toks) | ||
| 42 | + assert toks[0] not in tokens_table, f"Duplicate token: {toks} " | ||
| 43 | + tokens_table[toks[0]] = int(toks[1]) | ||
| 32 | 44 | ||
| 33 | - if "char" == modeling_unit: | ||
| 34 | - for context in contexts: | ||
| 35 | - assert ' ' not in context | ||
| 36 | - ids = [ | ||
| 37 | - tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"] | ||
| 38 | - for txt in context | ||
| 39 | - ] | ||
| 40 | - contexts_list.append(ids) | ||
| 41 | - elif "bpe" == modeling_unit: | ||
| 42 | - contexts_list = sp.encode(contexts, out_type=int) | ||
| 43 | - else: | ||
| 44 | - assert modeling_unit == "bpe+char", modeling_unit | 45 | + if "bpe" in tokens_type: |
| 46 | + assert Path(bpe_model).is_file(), f"File does not exist: {bpe_model}" | ||
| 47 | + sp = spm.SentencePieceProcessor() | ||
| 48 | + sp.load(bpe_model) | ||
| 45 | 49 | ||
| 50 | + texts_list: List[List[str]] = [] | ||
| 51 | + | ||
| 52 | + if tokens_type == "cjkchar": | ||
| 53 | + texts_list = [list("".join(text.split())) for text in texts] | ||
| 54 | + elif tokens_type == "bpe": | ||
| 55 | + texts_list = sp.encode(texts, out_type=str) | ||
| 56 | + else: | ||
| 57 | + assert ( | ||
| 58 | + tokens_type == "cjkchar+bpe" | ||
| 59 | + ), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}" | ||
| 46 | # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: | 60 | # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: |
| 47 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | 61 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) |
| 48 | pattern = re.compile(r"([\u4e00-\u9fff])") | 62 | pattern = re.compile(r"([\u4e00-\u9fff])") |
| 49 | - for context in contexts: | 63 | + for text in texts: |
| 50 | # Example: | 64 | # Example: |
| 51 | # txt = "你好 ITS'S OKAY 的" | 65 | # txt = "你好 ITS'S OKAY 的" |
| 52 | # chars = ["你", "好", " ITS'S OKAY ", "的"] | 66 | # chars = ["你", "好", " ITS'S OKAY ", "的"] |
| 53 | - chars = pattern.split(context.upper()) | 67 | + chars = pattern.split(text) |
| 54 | mix_chars = [w for w in chars if len(w.strip()) > 0] | 68 | mix_chars = [w for w in chars if len(w.strip()) > 0] |
| 55 | - ids = [] | 69 | + text_list = [] |
| 56 | for ch_or_w in mix_chars: | 70 | for ch_or_w in mix_chars: |
| 57 | # ch_or_w is a single CJK character (i.e., "你"), do nothing. | 71 | # ch_or_w is a single CJK character (i.e., "你"), do nothing. |
| 58 | if pattern.fullmatch(ch_or_w) is not None: | 72 | if pattern.fullmatch(ch_or_w) is not None: |
| 59 | - ids.append( | ||
| 60 | - tokens_table[ch_or_w] | ||
| 61 | - if ch_or_w in tokens_table | ||
| 62 | - else tokens_table["<unk>"] | ||
| 63 | - ) | 73 | + text_list.append(ch_or_w) |
| 64 | # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "), | 74 | # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "), |
| 65 | # encode ch_or_w using bpe_model. | 75 | # encode ch_or_w using bpe_model. |
| 66 | else: | 76 | else: |
| 67 | - for p in sp.encode_as_pieces(ch_or_w): | ||
| 68 | - ids.append( | ||
| 69 | - tokens_table[p] | ||
| 70 | - if p in tokens_table | ||
| 71 | - else tokens_table["<unk>"] | ||
| 72 | - ) | ||
| 73 | - contexts_list.append(ids) | ||
| 74 | - return contexts_list | 77 | + text_list += sp.encode_as_pieces(ch_or_w) |
| 78 | + texts_list.append(text_list) | ||
| 79 | + | ||
| 80 | + result: List[List[Union[int, str]]] = [] | ||
| 81 | + for text in texts_list: | ||
| 82 | + text_list = [] | ||
| 83 | + contain_oov = False | ||
| 84 | + for txt in text: | ||
| 85 | + if txt in tokens_table: | ||
| 86 | + text_list.append(tokens_table[txt] if output_ids else txt) | ||
| 87 | + else: | ||
| 88 | + print(f"OOV token: {txt}, skipping text: {text}.") | ||
| 89 | + contain_oov = True | ||
| 90 | + break | ||
| 91 | + if contain_oov: | ||
| 92 | + continue | ||
| 93 | + else: | ||
| 94 | + result.append(text_list) | ||
| 95 | + return result |
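As a usage note, text2token can also be called directly to prepare a hotwords file for --hotwords-file. A small sketch, assuming an English BPE model and made-up file names:

from sherpa_onnx import text2token

phrases = ["HELLO WORLD", "I LOVE YOU"]
encoded = text2token(
    phrases,
    tokens="tokens.txt",    # placeholder path to the model's tokens.txt
    tokens_type="bpe",
    bpe_model="bpe.model",  # placeholder path to the model's bpe.model
)

# Each output line holds the space-separated tokens of one phrase,
# e.g. "▁HE LL O ▁WORLD".
with open("hotwords.txt", "w", encoding="utf-8") as f:
    for toks in encoded:
        f.write(" ".join(toks) + "\n")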
| @@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source) | @@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source) | ||
| 6 | COMMAND | 6 | COMMAND |
| 7 | "${PYTHON_EXECUTABLE}" | 7 | "${PYTHON_EXECUTABLE}" |
| 8 | "${CMAKE_CURRENT_SOURCE_DIR}/${source}" | 8 | "${CMAKE_CURRENT_SOURCE_DIR}/${source}" |
| 9 | + WORKING_DIRECTORY | ||
| 10 | + ${CMAKE_CURRENT_SOURCE_DIR} | ||
| 9 | ) | 11 | ) |
| 10 | 12 | ||
| 11 | get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) | 13 | get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) |
| 12 | 14 | ||
| 13 | set_property(TEST ${name} | 15 | set_property(TEST ${name} |
| 14 | - PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}" | 16 | + PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}" |
| 15 | ) | 17 | ) |
| 16 | endfunction() | 18 | endfunction() |
| 17 | 19 | ||
| @@ -21,6 +23,7 @@ set(py_test_files | @@ -21,6 +23,7 @@ set(py_test_files | ||
| 21 | test_offline_recognizer.py | 23 | test_offline_recognizer.py |
| 22 | test_online_recognizer.py | 24 | test_online_recognizer.py |
| 23 | test_online_transducer_model_config.py | 25 | test_online_transducer_model_config.py |
| 26 | + test_text2token.py | ||
| 24 | ) | 27 | ) |
| 25 | 28 | ||
| 26 | foreach(source IN LISTS py_test_files) | 29 | foreach(source IN LISTS py_test_files) |
sherpa-onnx/python/tests/test_text2token.py
0 → 100644
| 1 | +# sherpa-onnx/python/tests/test_text2token.py | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +# | ||
| 5 | +# To run this single test, use | ||
| 6 | +# | ||
| 7 | +# ctest --verbose -R test_text2token_py | ||
| 8 | + | ||
| 9 | +import unittest | ||
| 10 | +from pathlib import Path | ||
| 11 | + | ||
| 12 | +import sherpa_onnx | ||
| 13 | + | ||
| 14 | +d = "/tmp/sherpa-test-data" | ||
| 15 | +# Please refer to | ||
| 16 | +# https://github.com/pkufool/sherpa-test-data | ||
| 17 | +# to download test data for testing | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +class TestText2Token(unittest.TestCase): | ||
| 21 | + def test_bpe(self): | ||
| 22 | + tokens = f"{d}/text2token/tokens_en.txt" | ||
| 23 | + bpe_model = f"{d}/text2token/bpe_en.model" | ||
| 24 | + | ||
| 25 | + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): | ||
| 26 | + print( | ||
| 27 | + f"No test data found, skipping test_bpe().\n" | ||
| 28 | + f"You can download the test data by: \n" | ||
| 29 | + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" | ||
| 30 | + ) | ||
| 31 | + return | ||
| 32 | + | ||
| 33 | + texts = ["HELLO WORLD", "I LOVE YOU"] | ||
| 34 | + encoded_texts = sherpa_onnx.text2token( | ||
| 35 | + texts, | ||
| 36 | + tokens=tokens, | ||
| 37 | + tokens_type="bpe", | ||
| 38 | + bpe_model=bpe_model, | ||
| 39 | + ) | ||
| 40 | + assert encoded_texts == [ | ||
| 41 | + ["▁HE", "LL", "O", "▁WORLD"], | ||
| 42 | + ["▁I", "▁LOVE", "▁YOU"], | ||
| 43 | + ], encoded_texts | ||
| 44 | + | ||
| 45 | + encoded_ids = sherpa_onnx.text2token( | ||
| 46 | + texts, | ||
| 47 | + tokens=tokens, | ||
| 48 | + tokens_type="bpe", | ||
| 49 | + bpe_model=bpe_model, | ||
| 50 | + output_ids=True, | ||
| 51 | + ) | ||
| 52 | + assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids | ||
| 53 | + | ||
| 54 | + def test_cjkchar(self): | ||
| 55 | + tokens = f"{d}/text2token/tokens_cn.txt" | ||
| 56 | + | ||
| 57 | + if not Path(tokens).is_file(): | ||
| 58 | + print( | ||
| 59 | + f"No test data found, skipping test_cjkchar().\n" | ||
| 60 | + f"You can download the test data by: \n" | ||
| 61 | + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" | ||
| 62 | + ) | ||
| 63 | + return | ||
| 64 | + | ||
| 65 | + texts = ["世界人民大团结", "中国 VS 美国"] | ||
| 66 | + encoded_texts = sherpa_onnx.text2token( | ||
| 67 | + texts, tokens=tokens, tokens_type="cjkchar" | ||
| 68 | + ) | ||
| 69 | + assert encoded_texts == [ | ||
| 70 | + ["世", "界", "人", "民", "大", "团", "结"], | ||
| 71 | + ["中", "国", "V", "S", "美", "国"], | ||
| 72 | + ], encoded_texts | ||
| 73 | + encoded_ids = sherpa_onnx.text2token( | ||
| 74 | + texts, | ||
| 75 | + tokens=tokens, | ||
| 76 | + tokens_type="cjkchar", | ||
| 77 | + output_ids=True, | ||
| 78 | + ) | ||
| 79 | + assert encoded_ids == [ | ||
| 80 | + [379, 380, 72, 874, 93, 1251, 489], | ||
| 81 | + [262, 147, 3423, 2476, 21, 147], | ||
| 82 | + ], encoded_ids | ||
| 83 | + | ||
| 84 | + def test_cjkchar_bpe(self): | ||
| 85 | + tokens = f"{d}/text2token/tokens_mix.txt" | ||
| 86 | + bpe_model = f"{d}/text2token/bpe_mix.model" | ||
| 87 | + | ||
| 88 | + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): | ||
| 89 | + print( | ||
| 90 | + f"No test data found, skipping test_cjkchar_bpe().\n" | ||
| 91 | + f"You can download the test data by: \n" | ||
| 92 | + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" | ||
| 93 | + ) | ||
| 94 | + return | ||
| 95 | + | ||
| 96 | + texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"] | ||
| 97 | + encoded_texts = sherpa_onnx.text2token( | ||
| 98 | + texts, | ||
| 99 | + tokens=tokens, | ||
| 100 | + tokens_type="cjkchar+bpe", | ||
| 101 | + bpe_model=bpe_model, | ||
| 102 | + ) | ||
| 103 | + assert encoded_texts == [ | ||
| 104 | + ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"], | ||
| 105 | + ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"], | ||
| 106 | + ], encoded_texts | ||
| 107 | + encoded_ids = sherpa_onnx.text2token( | ||
| 108 | + texts, | ||
| 109 | + tokens=tokens, | ||
| 110 | + tokens_type="cjkchar+bpe", | ||
| 111 | + bpe_model=bpe_model, | ||
| 112 | + output_ids=True, | ||
| 113 | + ) | ||
| 114 | + assert encoded_ids == [ | ||
| 115 | + [1368, 1392, 557, 680, 275, 178, 475], | ||
| 116 | + [685, 736, 275, 178, 179, 921, 736], | ||
| 117 | + ], encoded_ids | ||
| 118 | + | ||
| 119 | + | ||
| 120 | +if __name__ == "__main__": | ||
| 121 | + unittest.main() |