frankyoujian
Committed by GitHub

Support real time hotwords on python (#230)

* support real time hotwords on python

* fix comments
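
A hypothetical invocation exercising the new flags (the script name and the
omitted model arguments are assumptions, not part of this diff; only the
flags shown below are added by this commit):

    python3 speech-recognition-from-microphone.py \
        --tokens=./tokens.txt \
        --decoding-method=modified_beam_search \
        --max-active-paths=4 \
        --modeling-unit=bpe \
        --bpe-model=./bpe.model \
        --contexts="HELLO WORLD/GO AWAY" \
        --context-score=1.5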
@@ -10,6 +10,9 @@ import argparse
 import sys
 from pathlib import Path
 
+from typing import List, Tuple
+import sentencepiece as spm
+
 try:
     import sounddevice as sd
 except ImportError:
@@ -70,6 +73,59 @@ def get_args():
         help="Valid values are greedy_search and modified_beam_search",
     )
 
+    parser.add_argument(
+        "--max-active-paths",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is modified_beam_search.
+        It specifies the number of active paths to keep during decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="",
+        help="""
+        Path to bpe.model. It is used to tokenize the context biasing phrases.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="char",
+        help="""
+        The type of modeling unit; it is used to tokenize the context biasing phrases.
+        Valid values are bpe, bpe+char, char.
+        Note: char here means characters in CJK languages.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--contexts",
+        type=str,
+        default="",
+        help="""
+        The context list: a string of words/phrases separated by /,
+        for example, "HELLO WORLD/I LOVE YOU/GO AWAY".
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=1.5,
+        help="""
+        The context score added per token for each biasing word/phrase.
+        Used only if --contexts is given and
+        --decoding-method=modified_beam_search
+        """,
+    )
+
     return parser.parse_args()
 
 
@@ -91,11 +147,40 @@ def create_recognizer():
         sample_rate=16000,
         feature_dim=80,
         decoding_method=args.decoding_method,
+        max_active_paths=args.max_active_paths,
+        context_score=args.context_score,
     )
     return recognizer
 
+def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
+    sp = None
+    if "bpe" in args.modeling_unit:
+        assert_file_exists(args.bpe_model)
+        sp = spm.SentencePieceProcessor()
+        sp.load(args.bpe_model)
+    tokens = {}
+    with open(args.tokens, "r", encoding="utf-8") as f:
+        for line in f:
+            toks = line.strip().split()
+            assert len(toks) == 2, len(toks)
+            assert toks[0] not in tokens, f"Duplicate token: {toks}"
+            tokens[toks[0]] = int(toks[1])
+    return sherpa_onnx.encode_contexts(
+        modeling_unit=args.modeling_unit,
+        contexts=contexts,
+        sp=sp,
+        tokens_table=tokens,
+    )
 
 def main():
+    args = get_args()
+
+    contexts_list = []
+    contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
+    if contexts:
+        print(f"Contexts list: {contexts}")
+        contexts_list = encode_contexts(args, contexts)
+
     recognizer = create_recognizer()
     print("Started! Please speak")
 
@@ -104,7 +189,10 @@ def main():
     sample_rate = 48000
     samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""
-    stream = recognizer.create_stream()
+    if contexts_list:
+        stream = recognizer.create_stream(contexts_list=contexts_list)
+    else:
+        stream = recognizer.create_stream()
     with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
         while True:
             samples, _ = s.read(samples_per_read)  # a blocking read
@@ -117,7 +205,6 @@ def main():
             last_result = result
             print("\r{}".format(result), end="", flush=True)
 
-
 if __name__ == "__main__":
     devices = sd.query_devices()
     print(devices)
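
For reference, a minimal sketch of the new hotword path, mirroring the diff
above. The tokens.txt path and the phrase list are made-up examples, and
create_recognizer() is the function from the script (it still needs the usual
model arguments to run):

    import sherpa_onnx

    # Build the token table the same way encode_contexts() above does:
    # tokens.txt holds one "<token> <id>" pair per line (assumed path).
    tokens_table = {}
    with open("tokens.txt", "r", encoding="utf-8") as f:
        for line in f:
            tok, idx = line.strip().split()
            tokens_table[tok] = int(idx)

    # Split the "/"-separated phrase list exactly as main() does.
    contexts = [x.strip().upper() for x in "HELLO WORLD/GO AWAY".split("/") if x.strip()]

    # With modeling_unit="char" no SentencePiece model is needed, so sp=None,
    # matching the sp handling in encode_contexts() above.
    contexts_list = sherpa_onnx.encode_contexts(
        modeling_unit="char",
        contexts=contexts,
        sp=None,
        tokens_table=tokens_table,
    )

    # Streams created with contexts_list bias modified_beam_search decoding
    # toward the given phrases; plain create_stream() leaves decoding unbiased.
    recognizer = create_recognizer()  # from the script above
    stream = recognizer.create_stream(contexts_list=contexts_list)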