Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-04-19 18:33:18 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-04-19 18:33:18 +0800
Commit
54bc504065bfd8c07bfc2957560a9e7f7ccb550c
54bc5040
1 parent
c1608b35
Add Python API example for CED audio tagging. (#793)
显示空白字符变更
内嵌
并排对比
正在显示
2 个修改的文件
包含
122 行增加
和
3 行删除
python-api-examples/audio-tagging-from-a-file-ced.py
sherpa-onnx/python/csrc/audio-tagging.cc
python-api-examples/audio-tagging-from-a-file-ced.py
0 → 100755
查看文件 @
54bc504
#!/usr/bin/env python3
"""
This script shows how to use audio tagging Python APIs to tag a file.
Please read the code to download the required model files and test wave file.
"""
import
logging
import
time
from
pathlib
import
Path
import
numpy
as
np
import
sherpa_onnx
import
soundfile
as
sf
def
read_test_wave
():
# Please download the model files and test wave files from
# https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
test_wave
=
"./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/6.wav"
if
not
Path
(
test_wave
)
.
is_file
():
raise
ValueError
(
f
"Please download {test_wave} from "
"https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
)
# See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
data
,
sample_rate
=
sf
.
read
(
test_wave
,
always_2d
=
True
,
dtype
=
"float32"
,
)
data
=
data
[:,
0
]
# use only the first channel
samples
=
np
.
ascontiguousarray
(
data
)
# samples is a 1-d array of dtype float32
# sample_rate is a scalar
return
samples
,
sample_rate
def
create_audio_tagger
():
# Please download the model files and test wave files from
# https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
model_file
=
"./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.int8.onnx"
label_file
=
(
"./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/class_labels_indices.csv"
)
if
not
Path
(
model_file
)
.
is_file
():
raise
ValueError
(
f
"Please download {model_file} from "
"https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
)
if
not
Path
(
label_file
)
.
is_file
():
raise
ValueError
(
f
"Please download {label_file} from "
"https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
)
config
=
sherpa_onnx
.
AudioTaggingConfig
(
model
=
sherpa_onnx
.
AudioTaggingModelConfig
(
ced
=
model_file
,
num_threads
=
1
,
debug
=
True
,
provider
=
"cpu"
,
),
labels
=
label_file
,
top_k
=
5
,
)
if
not
config
.
validate
():
raise
ValueError
(
f
"Please check the config: {config}"
)
print
(
config
)
return
sherpa_onnx
.
AudioTagging
(
config
)
def
main
():
logging
.
info
(
"Create audio tagger"
)
audio_tagger
=
create_audio_tagger
()
logging
.
info
(
"Read test wave"
)
samples
,
sample_rate
=
read_test_wave
()
logging
.
info
(
"Computing"
)
start_time
=
time
.
time
()
stream
=
audio_tagger
.
create_stream
()
stream
.
accept_waveform
(
sample_rate
=
sample_rate
,
waveform
=
samples
)
result
=
audio_tagger
.
compute
(
stream
)
end_time
=
time
.
time
()
elapsed_seconds
=
end_time
-
start_time
audio_duration
=
len
(
samples
)
/
sample_rate
real_time_factor
=
elapsed_seconds
/
audio_duration
logging
.
info
(
f
"Elapsed seconds: {elapsed_seconds:.3f}"
)
logging
.
info
(
f
"Audio duration in seconds: {audio_duration:.3f}"
)
logging
.
info
(
f
"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
)
s
=
"
\n
"
for
i
,
e
in
enumerate
(
result
):
s
+=
f
"{i}: {e}
\n
"
logging
.
info
(
s
)
if
__name__
==
"__main__"
:
formatter
=
"
%(asctime)
s
%(levelname)
s [
%(filename)
s:
%(lineno)
d]
%(message)
s"
logging
.
basicConfig
(
format
=
formatter
,
level
=
logging
.
INFO
)
main
()
...
...
sherpa-onnx/python/csrc/audio-tagging.cc
查看文件 @
54bc504
...
...
@@ -29,9 +29,9 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
const
OfflineZipformerAudioTaggingModelConfig
&
,
const
std
::
string
&
,
int32_t
,
bool
,
const
std
::
string
&>
(),
py
::
arg
(
"zipformer"
),
py
::
arg
(
"ced"
)
=
""
,
py
::
arg
(
"num_threads"
)
=
1
,
py
::
arg
(
"debug"
)
=
false
,
py
::
arg
(
"provider"
)
=
"cpu"
)
py
::
arg
(
"zipformer"
)
=
OfflineZipformerAudioTaggingModelConfig
{},
py
::
arg
(
"ced"
)
=
""
,
py
::
arg
(
"num_threads"
)
=
1
,
py
::
arg
(
"debug"
)
=
false
,
py
::
arg
(
"provider"
)
=
"cpu"
)
.
def_readwrite
(
"zipformer"
,
&
PyClass
::
zipformer
)
.
def_readwrite
(
"num_threads"
,
&
PyClass
::
num_threads
)
.
def_readwrite
(
"debug"
,
&
PyClass
::
debug
)
...
...
请
注册
或
登录
后发表评论