Committed by GitHub
Support specifying providers in Python API (#198)
Showing 4 changed files with 17 additions and 1 deletion.
@@ -112,6 +112,7 @@ class OfflineRecognizer(object):
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
         debug: bool = False,
+        provider: str = "cpu",
     ):
         """
         Please refer to
@@ -138,6 +139,8 @@ class OfflineRecognizer(object):
            Valid values are greedy_search, modified_beam_search.
          debug:
            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         """
         self = cls.__new__(cls)
         model_config = OfflineModelConfig(
@@ -145,6 +148,7 @@ class OfflineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             debug=debug,
+            provider=provider,
         )

         feat_config = OfflineFeatureExtractorConfig(
@@ -170,6 +174,7 @@ class OfflineRecognizer(object):
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
         debug: bool = False,
+        provider: str = "cpu",
     ):
         """
         Please refer to
@@ -196,6 +201,8 @@ class OfflineRecognizer(object):
            Valid values are greedy_search, modified_beam_search.
          debug:
            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         """
         self = cls.__new__(cls)
         model_config = OfflineModelConfig(
@@ -203,6 +210,7 @@ class OfflineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             debug=debug,
+            provider=provider,
         )

         feat_config = OfflineFeatureExtractorConfig(
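With this change, a caller of the offline Python API can choose the onnxruntime execution provider when the recognizer is built. Below is a minimal sketch of how the new keyword might be used; the from_transducer factory name and the model file paths are assumptions for illustration and are not part of this diff:

import sherpa_onnx

# Sketch only: the file paths are placeholders, not real models.
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    tokens="tokens.txt",
    num_threads=1,
    decoding_method="greedy_search",
    provider="cuda",  # new keyword from this commit; valid values are cpu, cuda, coreml
)

The default stays provider="cpu", so existing callers are unaffected.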
@@ -131,7 +131,7 @@ class OnlineRecognizer(object):
         self.recognizer = _Recognizer(recognizer_config)
         self.config = recognizer_config

-    def create_stream(self, contexts_list : Optional[List[List[int]]] = None):
+    def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
         if contexts_list is None:
             return self.recognizer.create_stream()
         else:
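The OnlineRecognizer hunk above only fixes the spacing of the type annotation, but the signature shows that create_stream optionally takes a list of token-ID lists (presumably for contextual biasing). A hedged sketch of both call forms, assuming recognizer is an OnlineRecognizer instance; the IDs below are placeholders, since real values depend on the model's tokens file:

# No contexts: behaves exactly as before.
stream = recognizer.create_stream()

# With contexts: each inner list is one context phrase encoded as token IDs.
stream = recognizer.create_stream(contexts_list=[[23, 51, 9], [7, 112]])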
@@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase):
             joiner=joiner,
             tokens=tokens,
             num_threads=1,
+            provider="cpu",
         )

         s = recognizer.create_stream()
@@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase):
             joiner=joiner,
             tokens=tokens,
             num_threads=1,
+            provider="cpu",
         )

         s0 = recognizer.create_stream()
@@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase):
             paraformer=model,
             tokens=tokens,
             num_threads=1,
+            provider="cpu",
         )

         s = recognizer.create_stream()
@@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase):
             paraformer=model,
             tokens=tokens,
             num_threads=1,
+            provider="cpu",
         )

         s0 = recognizer.create_stream()
@@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase):
             model=model,
             tokens=tokens,
             num_threads=1,
+            provider="cpu",
         )

         s = recognizer.create_stream()
@@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase):
             model=model,
             tokens=tokens,
             num_threads=1,
+            provider="cpu",
         )

         s0 = recognizer.create_stream()
@@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase):
             tokens=tokens,
             num_threads=1,
             decoding_method=decoding_method,
+            provider="cpu",
         )
         s = recognizer.create_stream()
         samples, sample_rate = read_wave(wave0)
@@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase):
             tokens=tokens,
             num_threads=1,
             decoding_method=decoding_method,
+            provider="cpu",
         )
         streams = []
         waves = [wave0, wave1, wave2, wave3, wave4]
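The test updates simply pin provider="cpu" so the suites keep running on machines without a GPU, and they exercise the new keyword on both the offline and the streaming recognizer. A sketch of the streaming-side construction, mirroring the keyword arguments visible in the test hunk; the encoder/decoder/joiner paths and the exact constructor entry point are assumptions, not part of this diff:

import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="tokens.txt",          # placeholder path
    encoder="encoder.onnx",       # placeholder path
    decoder="decoder.onnx",       # placeholder path
    joiner="joiner.onnx",         # placeholder path
    num_threads=1,
    decoding_method="greedy_search",
    provider="cpu",               # same value the tests pin to stay GPU-free
)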