Committed by
GitHub
Support specifying provider for python examples (#244)
正在显示 4 个修改的文件，包含 35 行增加和 1 行删除
| @@ -80,6 +80,13 @@ def get_args(): | @@ -80,6 +80,13 @@ def get_args(): | ||
| 80 | ) | 80 | ) |
| 81 | 81 | ||
| 82 | parser.add_argument( | 82 | parser.add_argument( |
| 83 | + "--provider", | ||
| 84 | + type=str, | ||
| 85 | + default="cpu", | ||
| 86 | + help="Valid values: cpu, cuda, coreml", | ||
| 87 | + ) | ||
| 88 | + | ||
| 89 | + parser.add_argument( | ||
| 83 | "--bpe-model", | 90 | "--bpe-model", |
| 84 | type=str, | 91 | type=str, |
| 85 | default="", | 92 | default="", |
| @@ -204,6 +211,7 @@ def main(): | @@ -204,6 +211,7 @@ def main(): | ||
| 204 | decoder=args.decoder, | 211 | decoder=args.decoder, |
| 205 | joiner=args.joiner, | 212 | joiner=args.joiner, |
| 206 | num_threads=args.num_threads, | 213 | num_threads=args.num_threads, |
| 214 | + provider=args.provider, | ||
| 207 | sample_rate=16000, | 215 | sample_rate=16000, |
| 208 | feature_dim=80, | 216 | feature_dim=80, |
| 209 | decoding_method=args.decoding_method, | 217 | decoding_method=args.decoding_method, |
| @@ -220,7 +228,6 @@ def main(): | @@ -220,7 +228,6 @@ def main(): | ||
| 220 | print(f"Contexts list: {contexts}") | 228 | print(f"Contexts list: {contexts}") |
| 221 | contexts_list = encode_contexts(args, contexts) | 229 | contexts_list = encode_contexts(args, contexts) |
| 222 | 230 | ||
| 223 | - | ||
| 224 | streams = [] | 231 | streams = [] |
| 225 | total_duration = 0 | 232 | total_duration = 0 |
| 226 | for wave_filename in args.sound_files: | 233 | for wave_filename in args.sound_files: |
| @@ -72,6 +72,13 @@ def get_args(): | @@ -72,6 +72,13 @@ def get_args(): | ||
| 72 | help="Valid values are greedy_search and modified_beam_search", | 72 | help="Valid values are greedy_search and modified_beam_search", |
| 73 | ) | 73 | ) |
| 74 | 74 | ||
| 75 | + parser.add_argument( | ||
| 76 | + "--provider", | ||
| 77 | + type=str, | ||
| 78 | + default="cpu", | ||
| 79 | + help="Valid values: cpu, cuda, coreml", | ||
| 80 | + ) | ||
| 81 | + | ||
| 75 | return parser.parse_args() | 82 | return parser.parse_args() |
| 76 | 83 | ||
| 77 | 84 | ||
| @@ -97,6 +104,7 @@ def create_recognizer(): | @@ -97,6 +104,7 @@ def create_recognizer(): | ||
| 97 | rule2_min_trailing_silence=1.2, | 104 | rule2_min_trailing_silence=1.2, |
| 98 | rule3_min_utterance_length=300, # it essentially disables this rule | 105 | rule3_min_utterance_length=300, # it essentially disables this rule |
| 99 | decoding_method=args.decoding_method, | 106 | decoding_method=args.decoding_method, |
| 107 | + provider=args.provider, | ||
| 100 | ) | 108 | ) |
| 101 | return recognizer | 109 | return recognizer |
| 102 | 110 |
| @@ -83,6 +83,13 @@ def get_args(): | @@ -83,6 +83,13 @@ def get_args(): | ||
| 83 | ) | 83 | ) |
| 84 | 84 | ||
| 85 | parser.add_argument( | 85 | parser.add_argument( |
| 86 | + "--provider", | ||
| 87 | + type=str, | ||
| 88 | + default="cpu", | ||
| 89 | + help="Valid values: cpu, cuda, coreml", | ||
| 90 | + ) | ||
| 91 | + | ||
| 92 | + parser.add_argument( | ||
| 86 | "--bpe-model", | 93 | "--bpe-model", |
| 87 | type=str, | 94 | type=str, |
| 88 | default="", | 95 | default="", |
| @@ -148,10 +155,12 @@ def create_recognizer(): | @@ -148,10 +155,12 @@ def create_recognizer(): | ||
| 148 | feature_dim=80, | 155 | feature_dim=80, |
| 149 | decoding_method=args.decoding_method, | 156 | decoding_method=args.decoding_method, |
| 150 | max_active_paths=args.max_active_paths, | 157 | max_active_paths=args.max_active_paths, |
| 158 | + provider=args.provider, | ||
| 151 | context_score=args.context_score, | 159 | context_score=args.context_score, |
| 152 | ) | 160 | ) |
| 153 | return recognizer | 161 | return recognizer |
| 154 | 162 | ||
| 163 | + | ||
| 155 | def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | 164 | def encode_contexts(args, contexts: List[str]) -> List[List[int]]: |
| 156 | sp = None | 165 | sp = None |
| 157 | if "bpe" in args.modeling_unit: | 166 | if "bpe" in args.modeling_unit: |
| @@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | @@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | ||
| 172 | tokens_table=tokens, | 181 | tokens_table=tokens, |
| 173 | ) | 182 | ) |
| 174 | 183 | ||
| 184 | + | ||
| 175 | def main(): | 185 | def main(): |
| 176 | args = get_args() | 186 | args = get_args() |
| 177 | 187 | ||
| @@ -205,6 +215,7 @@ def main(): | @@ -205,6 +215,7 @@ def main(): | ||
| 205 | last_result = result | 215 | last_result = result |
| 206 | print("\r{}".format(result), end="", flush=True) | 216 | print("\r{}".format(result), end="", flush=True) |
| 207 | 217 | ||
| 218 | + | ||
| 208 | if __name__ == "__main__": | 219 | if __name__ == "__main__": |
| 209 | devices = sd.query_devices() | 220 | devices = sd.query_devices() |
| 210 | print(devices) | 221 | print(devices) |
| @@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser): | @@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser): | ||
| 129 | help="Feature dimension of the model", | 129 | help="Feature dimension of the model", |
| 130 | ) | 130 | ) |
| 131 | 131 | ||
| 132 | + parser.add_argument( | ||
| 133 | + "--provider", | ||
| 134 | + type=str, | ||
| 135 | + default="cpu", | ||
| 136 | + help="Valid values: cpu, cuda, coreml", | ||
| 137 | + ) | ||
| 138 | + | ||
| 132 | 139 | ||
| 133 | def add_decoding_args(parser: argparse.ArgumentParser): | 140 | def add_decoding_args(parser: argparse.ArgumentParser): |
| 134 | parser.add_argument( | 141 | parser.add_argument( |
| @@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | @@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | ||
| 301 | rule1_min_trailing_silence=args.rule1_min_trailing_silence, | 308 | rule1_min_trailing_silence=args.rule1_min_trailing_silence, |
| 302 | rule2_min_trailing_silence=args.rule2_min_trailing_silence, | 309 | rule2_min_trailing_silence=args.rule2_min_trailing_silence, |
| 303 | rule3_min_utterance_length=args.rule3_min_utterance_length, | 310 | rule3_min_utterance_length=args.rule3_min_utterance_length, |
| 311 | + provider=args.provider, | ||
| 304 | ) | 312 | ) |
| 305 | 313 | ||
| 306 | return recognizer | 314 | return recognizer |
-
请注册或登录后发表评论