Fangjun Kuang
Committed by GitHub

Support specifying providers in Python API (#198)

... ... @@ -112,6 +112,7 @@ class OfflineRecognizer(object):
feature_dim: int = 80,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
... ... @@ -138,6 +139,8 @@ class OfflineRecognizer(object):
Valid values are greedy_search, modified_beam_search.
debug:
True to show debug messages.
provider:
            onnxruntime execution provider to use. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
... ... @@ -145,6 +148,7 @@ class OfflineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
)
feat_config = OfflineFeatureExtractorConfig(
... ... @@ -170,6 +174,7 @@ class OfflineRecognizer(object):
feature_dim: int = 80,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
... ... @@ -196,6 +201,8 @@ class OfflineRecognizer(object):
Valid values are greedy_search, modified_beam_search.
debug:
True to show debug messages.
provider:
            onnxruntime execution provider to use. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
... ... @@ -203,6 +210,7 @@ class OfflineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
)
feat_config = OfflineFeatureExtractorConfig(
... ...
... ... @@ -131,7 +131,7 @@ class OnlineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
def create_stream(self, contexts_list : Optional[List[List[int]]] = None):
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
else:
... ...
... ... @@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase):
joiner=joiner,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s = recognizer.create_stream()
... ... @@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase):
joiner=joiner,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s0 = recognizer.create_stream()
... ... @@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase):
paraformer=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s = recognizer.create_stream()
... ... @@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase):
paraformer=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s0 = recognizer.create_stream()
... ... @@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase):
model=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s = recognizer.create_stream()
... ... @@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase):
model=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
s0 = recognizer.create_stream()
... ...
... ... @@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase):
tokens=tokens,
num_threads=1,
decoding_method=decoding_method,
provider="cpu",
)
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave0)
... ... @@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase):
tokens=tokens,
num_threads=1,
decoding_method=decoding_method,
provider="cpu",
)
streams = []
waves = [wave0, wave1, wave2, wave3, wave4]
... ...