Fangjun Kuang
Committed by GitHub

Support specifying provider for python examples (#244)

@@ -80,6 +80,13 @@ def get_args(): @@ -80,6 +80,13 @@ def get_args():
80 ) 80 )
81 81
82 parser.add_argument( 82 parser.add_argument(
  83 + "--provider",
  84 + type=str,
  85 + default="cpu",
  86 + help="Valid values: cpu, cuda, coreml",
  87 + )
  88 +
  89 + parser.add_argument(
83 "--bpe-model", 90 "--bpe-model",
84 type=str, 91 type=str,
85 default="", 92 default="",
@@ -204,6 +211,7 @@ def main(): @@ -204,6 +211,7 @@ def main():
204 decoder=args.decoder, 211 decoder=args.decoder,
205 joiner=args.joiner, 212 joiner=args.joiner,
206 num_threads=args.num_threads, 213 num_threads=args.num_threads,
  214 + provider=args.provider,
207 sample_rate=16000, 215 sample_rate=16000,
208 feature_dim=80, 216 feature_dim=80,
209 decoding_method=args.decoding_method, 217 decoding_method=args.decoding_method,
@@ -220,7 +228,6 @@ def main(): @@ -220,7 +228,6 @@ def main():
220 print(f"Contexts list: {contexts}") 228 print(f"Contexts list: {contexts}")
221 contexts_list = encode_contexts(args, contexts) 229 contexts_list = encode_contexts(args, contexts)
222 230
223 -  
224 streams = [] 231 streams = []
225 total_duration = 0 232 total_duration = 0
226 for wave_filename in args.sound_files: 233 for wave_filename in args.sound_files:
@@ -72,6 +72,13 @@ def get_args(): @@ -72,6 +72,13 @@ def get_args():
72 help="Valid values are greedy_search and modified_beam_search", 72 help="Valid values are greedy_search and modified_beam_search",
73 ) 73 )
74 74
  75 + parser.add_argument(
  76 + "--provider",
  77 + type=str,
  78 + default="cpu",
  79 + help="Valid values: cpu, cuda, coreml",
  80 + )
  81 +
75 return parser.parse_args() 82 return parser.parse_args()
76 83
77 84
@@ -97,6 +104,7 @@ def create_recognizer(): @@ -97,6 +104,7 @@ def create_recognizer():
97 rule2_min_trailing_silence=1.2, 104 rule2_min_trailing_silence=1.2,
98 rule3_min_utterance_length=300, # it essentially disables this rule 105 rule3_min_utterance_length=300, # it essentially disables this rule
99 decoding_method=args.decoding_method, 106 decoding_method=args.decoding_method,
  107 + provider=args.provider,
100 ) 108 )
101 return recognizer 109 return recognizer
102 110
@@ -83,6 +83,13 @@ def get_args(): @@ -83,6 +83,13 @@ def get_args():
83 ) 83 )
84 84
85 parser.add_argument( 85 parser.add_argument(
  86 + "--provider",
  87 + type=str,
  88 + default="cpu",
  89 + help="Valid values: cpu, cuda, coreml",
  90 + )
  91 +
  92 + parser.add_argument(
86 "--bpe-model", 93 "--bpe-model",
87 type=str, 94 type=str,
88 default="", 95 default="",
@@ -148,10 +155,12 @@ def create_recognizer(): @@ -148,10 +155,12 @@ def create_recognizer():
148 feature_dim=80, 155 feature_dim=80,
149 decoding_method=args.decoding_method, 156 decoding_method=args.decoding_method,
150 max_active_paths=args.max_active_paths, 157 max_active_paths=args.max_active_paths,
  158 + provider=args.provider,
151 context_score=args.context_score, 159 context_score=args.context_score,
152 ) 160 )
153 return recognizer 161 return recognizer
154 162
  163 +
155 def encode_contexts(args, contexts: List[str]) -> List[List[int]]: 164 def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
156 sp = None 165 sp = None
157 if "bpe" in args.modeling_unit: 166 if "bpe" in args.modeling_unit:
@@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: @@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
172 tokens_table=tokens, 181 tokens_table=tokens,
173 ) 182 )
174 183
  184 +
175 def main(): 185 def main():
176 args = get_args() 186 args = get_args()
177 187
@@ -205,6 +215,7 @@ def main(): @@ -205,6 +215,7 @@ def main():
205 last_result = result 215 last_result = result
206 print("\r{}".format(result), end="", flush=True) 216 print("\r{}".format(result), end="", flush=True)
207 217
  218 +
208 if __name__ == "__main__": 219 if __name__ == "__main__":
209 devices = sd.query_devices() 220 devices = sd.query_devices()
210 print(devices) 221 print(devices)
@@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser): @@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser):
129 help="Feature dimension of the model", 129 help="Feature dimension of the model",
130 ) 130 )
131 131
  132 + parser.add_argument(
  133 + "--provider",
  134 + type=str,
  135 + default="cpu",
  136 + help="Valid values: cpu, cuda, coreml",
  137 + )
  138 +
132 139
133 def add_decoding_args(parser: argparse.ArgumentParser): 140 def add_decoding_args(parser: argparse.ArgumentParser):
134 parser.add_argument( 141 parser.add_argument(
@@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: @@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
301 rule1_min_trailing_silence=args.rule1_min_trailing_silence, 308 rule1_min_trailing_silence=args.rule1_min_trailing_silence,
302 rule2_min_trailing_silence=args.rule2_min_trailing_silence, 309 rule2_min_trailing_silence=args.rule2_min_trailing_silence,
303 rule3_min_utterance_length=args.rule3_min_utterance_length, 310 rule3_min_utterance_length=args.rule3_min_utterance_length,
  311 + provider=args.provider,
304 ) 312 )
305 313
306 return recognizer 314 return recognizer