Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-07-06 10:14:01 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-07-06 10:14:01 +0800
Commit
33bf8dc1f47c8101f54105d837f3bf40e5a319cb
33bf8dc1
1 parent
3a08191a
Support specifying providers in Python API (#198)
隐藏空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
17 行增加
和
1 行删除
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
sherpa-onnx/python/tests/test_offline_recognizer.py
sherpa-onnx/python/tests/test_online_recognizer.py
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
查看文件 @
33bf8dc
...
...
@@ -112,6 +112,7 @@ class OfflineRecognizer(object):
feature_dim
:
int
=
80
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
):
"""
Please refer to
...
...
@@ -138,6 +139,8 @@ class OfflineRecognizer(object):
Valid values are greedy_search, modified_beam_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -145,6 +148,7 @@ class OfflineRecognizer(object):
tokens
=
tokens
,
num_threads
=
num_threads
,
debug
=
debug
,
provider
=
provider
,
)
feat_config
=
OfflineFeatureExtractorConfig
(
...
...
@@ -170,6 +174,7 @@ class OfflineRecognizer(object):
feature_dim
:
int
=
80
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
):
"""
Please refer to
...
...
@@ -196,6 +201,8 @@ class OfflineRecognizer(object):
Valid values are greedy_search, modified_beam_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -203,6 +210,7 @@ class OfflineRecognizer(object):
tokens
=
tokens
,
num_threads
=
num_threads
,
debug
=
debug
,
provider
=
provider
,
)
feat_config
=
OfflineFeatureExtractorConfig
(
...
...
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
查看文件 @
33bf8dc
...
...
@@ -131,7 +131,7 @@ class OnlineRecognizer(object):
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
def
create_stream
(
self
,
contexts_list
:
Optional
[
List
[
List
[
int
]]]
=
None
):
def
create_stream
(
self
,
contexts_list
:
Optional
[
List
[
List
[
int
]]]
=
None
):
if
contexts_list
is
None
:
return
self
.
recognizer
.
create_stream
()
else
:
...
...
sherpa-onnx/python/tests/test_offline_recognizer.py
查看文件 @
33bf8dc
...
...
@@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase):
joiner
=
joiner
,
tokens
=
tokens
,
num_threads
=
1
,
provider
=
"cpu"
,
)
s
=
recognizer
.
create_stream
()
...
...
@@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase):
joiner
=
joiner
,
tokens
=
tokens
,
num_threads
=
1
,
provider
=
"cpu"
,
)
s0
=
recognizer
.
create_stream
()
...
...
@@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase):
paraformer
=
model
,
tokens
=
tokens
,
num_threads
=
1
,
provider
=
"cpu"
,
)
s
=
recognizer
.
create_stream
()
...
...
@@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase):
paraformer
=
model
,
tokens
=
tokens
,
num_threads
=
1
,
provider
=
"cpu"
,
)
s0
=
recognizer
.
create_stream
()
...
...
@@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase):
model
=
model
,
tokens
=
tokens
,
num_threads
=
1
,
provider
=
"cpu"
,
)
s
=
recognizer
.
create_stream
()
...
...
@@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase):
model
=
model
,
tokens
=
tokens
,
num_threads
=
1
,
provider
=
"cpu"
,
)
s0
=
recognizer
.
create_stream
()
...
...
sherpa-onnx/python/tests/test_online_recognizer.py
查看文件 @
33bf8dc
...
...
@@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase):
tokens
=
tokens
,
num_threads
=
1
,
decoding_method
=
decoding_method
,
provider
=
"cpu"
,
)
s
=
recognizer
.
create_stream
()
samples
,
sample_rate
=
read_wave
(
wave0
)
...
...
@@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase):
tokens
=
tokens
,
num_threads
=
1
,
decoding_method
=
decoding_method
,
provider
=
"cpu"
,
)
streams
=
[]
waves
=
[
wave0
,
wave1
,
wave2
,
wave3
,
wave4
]
...
...
请
注册
或
登录
后发表评论