Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Jingzhao Ou
2023-06-03 23:13:55 -0700
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-06-04 14:13:55 +0800
Commit
fdd49d05387cb8dedbceeade644d8b0396c638d4
fdd49d05
1 parent
0ed501b8
add batch processing to sherpa-onnx (#166)
隐藏空白字符变更
内嵌
并排对比
正在显示
1 个修改的文件
包含
48 行增加
和
30 行删除
sherpa-onnx/csrc/sherpa-onnx.cc
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
fdd49d0
...
...
@@ -5,6 +5,8 @@
#include <stdio.h>
#include <chrono> // NOLINT
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
...
...
@@ -14,6 +16,12 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// Per-input-file state for batch decoding: one entry is created for every
// wave file given on the command line.
//
// Kept as a plain aggregate (no constructors, no default member initializers)
// so that brace-initialization such as `{std::move(s), duration, 0}` keeps
// working.
struct Stream {
  // Owning handle to the stream obtained from recognizer.CreateStream().
  std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;

  // Length of the input audio in seconds (number of samples / sampling rate).
  float duration;

  // Seconds measured from the `begin` timestamp in main() until the
  // recognizer first reports this stream as no longer ready. main()
  // initializes it to 0 and treats 0 as "not recorded yet".
  float elapsed_seconds;
};
int
main
(
int32_t
argc
,
char
*
argv
[])
{
const
char
*
kUsageMessage
=
R"usage(
Usage:
...
...
@@ -61,29 +69,26 @@ for a list of pre-trained models to download.
sherpa_onnx
::
OnlineRecognizer
recognizer
(
config
);
float
duration
=
0
;
std
::
vector
<
Stream
>
ss
;
const
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
std
::
vector
<
float
>
durations
;
for
(
int32_t
i
=
1
;
i
<=
po
.
NumArgs
();
++
i
)
{
const
std
::
string
wav_filename
=
po
.
GetArg
(
i
);
int32_t
sampling_rate
=
-
1
;
bool
is_ok
=
false
;
const
std
::
vector
<
float
>
samples
=
sherpa_onnx
::
ReadWave
(
wav_filename
,
&
sampling_rate
,
&
is_ok
);
sherpa_onnx
::
ReadWave
(
wav_filename
,
&
sampling_rate
,
&
is_ok
);
if
(
!
is_ok
)
{
fprintf
(
stderr
,
"Failed to read %s
\n
"
,
wav_filename
.
c_str
());
return
-
1
;
}
fprintf
(
stderr
,
"sampling rate of input file: %d
\n
"
,
sampling_rate
);
const
float
duration
=
samples
.
size
()
/
static_cast
<
float
>
(
sampling_rate
);
fprintf
(
stderr
,
"wav filename: %s
\n
"
,
wav_filename
.
c_str
());
fprintf
(
stderr
,
"wav duration (s): %.3f
\n
"
,
duration
);
fprintf
(
stderr
,
"Started
\n
"
);
const
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
auto
s
=
recognizer
.
CreateStream
();
s
->
AcceptWaveform
(
sampling_rate
,
samples
.
data
(),
samples
.
size
());
...
...
@@ -94,33 +99,46 @@ for a list of pre-trained models to download.
// Call InputFinished() to indicate that no audio samples are available
s
->
InputFinished
();
ss
.
push_back
({
std
::
move
(
s
),
duration
,
0
});
}
while
(
recognizer
.
IsReady
(
s
.
get
()))
{
recognizer
.
DecodeStream
(
s
.
get
());
}
const
std
::
string
text
=
recognizer
.
GetResult
(
s
.
get
()).
AsJsonString
();
const
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
const
float
elapsed_seconds
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
milliseconds
>
(
end
-
begin
)
std
::
vector
<
sherpa_onnx
::
OnlineStream
*>
ready_streams
;
for
(;;)
{
ready_streams
.
clear
();
for
(
auto
&
s
:
ss
)
{
const
auto
p_ss
=
s
.
online_stream
.
get
();
if
(
recognizer
.
IsReady
(
p_ss
))
{
ready_streams
.
push_back
(
p_ss
);
}
else
if
(
s
.
elapsed_seconds
==
0
)
{
const
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
const
float
elapsed_seconds
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
milliseconds
>
(
end
-
begin
)
.
count
()
/
1000.
;
s
.
elapsed_seconds
=
elapsed_seconds
;
}
}
fprintf
(
stderr
,
"Done!
\n
"
);
fprintf
(
stderr
,
"Recognition result for %s:
\n
%s
\n
"
,
wav_filename
.
c_str
(),
text
.
c_str
());
fprintf
(
stderr
,
"num threads: %d
\n
"
,
config
.
model_config
.
num_threads
);
fprintf
(
stderr
,
"decoding method: %s
\n
"
,
config
.
decoding_method
.
c_str
());
if
(
config
.
decoding_method
==
"modified_beam_search"
)
{
fprintf
(
stderr
,
"max active paths: %d
\n
"
,
config
.
max_active_paths
);
if
(
ready_streams
.
empty
())
{
break
;
}
fprintf
(
stderr
,
"Elapsed seconds: %.3f s
\n
"
,
elapsed_seconds
);
const
float
rtf
=
elapsed_seconds
/
duration
;
fprintf
(
stderr
,
"Real time factor (RTF): %.3f / %.3f = %.3f
\n
"
,
elapsed_seconds
,
duration
,
rtf
);
recognizer
.
DecodeStreams
(
ready_streams
.
data
(),
ready_streams
.
size
());
}
std
::
ostringstream
os
;
for
(
int32_t
i
=
1
;
i
<=
po
.
NumArgs
();
++
i
)
{
const
auto
&
s
=
ss
[
i
-
1
];
const
float
rtf
=
s
.
elapsed_seconds
/
s
.
duration
;
os
<<
po
.
GetArg
(
i
)
<<
"
\n
"
;
os
<<
std
::
setprecision
(
2
)
<<
"Elapsed seconds: "
<<
s
.
elapsed_seconds
<<
", Real time factor (RTF): "
<<
rtf
<<
"
\n
"
;
const
auto
r
=
recognizer
.
GetResult
(
s
.
online_stream
.
get
());
os
<<
r
.
text
<<
"
\n
"
;
os
<<
r
.
AsJsonString
()
<<
"
\n\n
"
;
}
std
::
cerr
<<
os
.
str
();
return
0
;
}
...
...
请
注册
或
登录
后发表评论