Fangjun Kuang
Committed by GitHub

Support specifying providers in Python API (#198)

@@ -112,6 +112,7 @@ class OfflineRecognizer(object): @@ -112,6 +112,7 @@ class OfflineRecognizer(object):
112 feature_dim: int = 80, 112 feature_dim: int = 80,
113 decoding_method: str = "greedy_search", 113 decoding_method: str = "greedy_search",
114 debug: bool = False, 114 debug: bool = False,
  115 + provider: str = "cpu",
115 ): 116 ):
116 """ 117 """
117 Please refer to 118 Please refer to
@@ -138,6 +139,8 @@ class OfflineRecognizer(object): @@ -138,6 +139,8 @@ class OfflineRecognizer(object):
138 Valid values are greedy_search, modified_beam_search. 139 Valid values are greedy_search, modified_beam_search.
139 debug: 140 debug:
140 True to show debug messages. 141 True to show debug messages.
  142 + provider:
  143 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
141 """ 144 """
142 self = cls.__new__(cls) 145 self = cls.__new__(cls)
143 model_config = OfflineModelConfig( 146 model_config = OfflineModelConfig(
@@ -145,6 +148,7 @@ class OfflineRecognizer(object): @@ -145,6 +148,7 @@ class OfflineRecognizer(object):
145 tokens=tokens, 148 tokens=tokens,
146 num_threads=num_threads, 149 num_threads=num_threads,
147 debug=debug, 150 debug=debug,
  151 + provider=provider,
148 ) 152 )
149 153
150 feat_config = OfflineFeatureExtractorConfig( 154 feat_config = OfflineFeatureExtractorConfig(
@@ -170,6 +174,7 @@ class OfflineRecognizer(object): @@ -170,6 +174,7 @@ class OfflineRecognizer(object):
170 feature_dim: int = 80, 174 feature_dim: int = 80,
171 decoding_method: str = "greedy_search", 175 decoding_method: str = "greedy_search",
172 debug: bool = False, 176 debug: bool = False,
  177 + provider: str = "cpu",
173 ): 178 ):
174 """ 179 """
175 Please refer to 180 Please refer to
@@ -196,6 +201,8 @@ class OfflineRecognizer(object): @@ -196,6 +201,8 @@ class OfflineRecognizer(object):
196 Valid values are greedy_search, modified_beam_search. 201 Valid values are greedy_search, modified_beam_search.
197 debug: 202 debug:
198 True to show debug messages. 203 True to show debug messages.
  204 + provider:
  205 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
199 """ 206 """
200 self = cls.__new__(cls) 207 self = cls.__new__(cls)
201 model_config = OfflineModelConfig( 208 model_config = OfflineModelConfig(
@@ -203,6 +210,7 @@ class OfflineRecognizer(object): @@ -203,6 +210,7 @@ class OfflineRecognizer(object):
203 tokens=tokens, 210 tokens=tokens,
204 num_threads=num_threads, 211 num_threads=num_threads,
205 debug=debug, 212 debug=debug,
  213 + provider=provider,
206 ) 214 )
207 215
208 feat_config = OfflineFeatureExtractorConfig( 216 feat_config = OfflineFeatureExtractorConfig(
@@ -131,7 +131,7 @@ class OnlineRecognizer(object): @@ -131,7 +131,7 @@ class OnlineRecognizer(object):
131 self.recognizer = _Recognizer(recognizer_config) 131 self.recognizer = _Recognizer(recognizer_config)
132 self.config = recognizer_config 132 self.config = recognizer_config
133 133
134 - def create_stream(self, contexts_list : Optional[List[List[int]]] = None): 134 + def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
135 if contexts_list is None: 135 if contexts_list is None:
136 return self.recognizer.create_stream() 136 return self.recognizer.create_stream()
137 else: 137 else:
@@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase):
72 joiner=joiner, 72 joiner=joiner,
73 tokens=tokens, 73 tokens=tokens,
74 num_threads=1, 74 num_threads=1,
  75 + provider="cpu",
75 ) 76 )
76 77
77 s = recognizer.create_stream() 78 s = recognizer.create_stream()
@@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase):
106 joiner=joiner, 107 joiner=joiner,
107 tokens=tokens, 108 tokens=tokens,
108 num_threads=1, 109 num_threads=1,
  110 + provider="cpu",
109 ) 111 )
110 112
111 s0 = recognizer.create_stream() 113 s0 = recognizer.create_stream()
@@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase):
143 paraformer=model, 145 paraformer=model,
144 tokens=tokens, 146 tokens=tokens,
145 num_threads=1, 147 num_threads=1,
  148 + provider="cpu",
146 ) 149 )
147 150
148 s = recognizer.create_stream() 151 s = recognizer.create_stream()
@@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase):
172 paraformer=model, 175 paraformer=model,
173 tokens=tokens, 176 tokens=tokens,
174 num_threads=1, 177 num_threads=1,
  178 + provider="cpu",
175 ) 179 )
176 180
177 s0 = recognizer.create_stream() 181 s0 = recognizer.create_stream()
@@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase):
214 model=model, 218 model=model,
215 tokens=tokens, 219 tokens=tokens,
216 num_threads=1, 220 num_threads=1,
  221 + provider="cpu",
217 ) 222 )
218 223
219 s = recognizer.create_stream() 224 s = recognizer.create_stream()
@@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase):
242 model=model, 247 model=model,
243 tokens=tokens, 248 tokens=tokens,
244 num_threads=1, 249 num_threads=1,
  250 + provider="cpu",
245 ) 251 )
246 252
247 s0 = recognizer.create_stream() 253 s0 = recognizer.create_stream()
@@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase): @@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase):
72 tokens=tokens, 72 tokens=tokens,
73 num_threads=1, 73 num_threads=1,
74 decoding_method=decoding_method, 74 decoding_method=decoding_method,
  75 + provider="cpu",
75 ) 76 )
76 s = recognizer.create_stream() 77 s = recognizer.create_stream()
77 samples, sample_rate = read_wave(wave0) 78 samples, sample_rate = read_wave(wave0)
@@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase): @@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase):
115 tokens=tokens, 116 tokens=tokens,
116 num_threads=1, 117 num_threads=1,
117 decoding_method=decoding_method, 118 decoding_method=decoding_method,
  119 + provider="cpu",
118 ) 120 )
119 streams = [] 121 streams = []
120 waves = [wave0, wave1, wave2, wave3, wave4] 122 waves = [wave0, wave1, wave2, wave3, wave4]