Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-08-31 14:41:04 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-08-31 14:41:04 +0800
Commit
a0a747a0c0df93cad346144d0f8f9c43bcacca83
a0a747a0
1 parent
2b0152d2
add endpointing for online websocket server (#294)
隐藏空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
27 行增加
和
2 行删除
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/online-stream.cc
sherpa-onnx/csrc/online-stream.h
sherpa-onnx/csrc/online-websocket-server-impl.cc
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
查看文件 @
a0a747a
...
...
@@ -26,7 +26,8 @@ namespace sherpa_onnx {
static
OnlineRecognizerResult
Convert
(
const
OnlineTransducerDecoderResult
&
src
,
const
SymbolTable
&
sym_table
,
int32_t
frame_shift_ms
,
int32_t
subsampling_factor
)
{
int32_t
subsampling_factor
,
int32_t
segment
)
{
OnlineRecognizerResult
r
;
r
.
tokens
.
reserve
(
src
.
tokens
.
size
());
r
.
timestamps
.
reserve
(
src
.
tokens
.
size
());
...
...
@@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r
.
timestamps
.
push_back
(
time
);
}
r
.
segment
=
segment
;
return
r
;
}
...
...
@@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
4
;
return
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
);
return
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
());
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
...
...
@@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
void
Reset
(
OnlineStream
*
s
)
const
override
{
{
// segment is incremented only when the last
// result is not empty
const
auto
&
r
=
s
->
GetResult
();
if
(
!
r
.
tokens
.
empty
()
&&
r
.
tokens
.
back
()
!=
0
)
{
s
->
GetCurrentSegment
()
+=
1
;
}
}
// we keep the decoder_out
decoder_
->
UpdateDecoderOut
(
&
s
->
GetResult
());
Ort
::
Value
decoder_out
=
std
::
move
(
s
->
GetResult
().
decoder_out
);
...
...
sherpa-onnx/csrc/online-stream.cc
查看文件 @
a0a747a
...
...
@@ -43,6 +43,8 @@ class OnlineStream::Impl {
int32_t
&
GetNumProcessedFrames
()
{
return
num_processed_frames_
;
}
int32_t
&
GetCurrentSegment
()
{
return
segment_
;
}
void
SetResult
(
const
OnlineTransducerDecoderResult
&
r
)
{
result_
=
r
;
}
OnlineTransducerDecoderResult
&
GetResult
()
{
return
result_
;
}
...
...
@@ -83,6 +85,7 @@ class OnlineStream::Impl {
ContextGraphPtr
context_graph_
;
int32_t
num_processed_frames_
=
0
;
// before subsampling
int32_t
start_frame_index_
=
0
;
// never reset
int32_t
segment_
=
0
;
OnlineTransducerDecoderResult
result_
;
std
::
vector
<
Ort
::
Value
>
states_
;
std
::
vector
<
float
>
paraformer_feat_cache_
;
...
...
@@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
return
impl_
->
GetNumProcessedFrames
();
}
int32_t
&
OnlineStream
::
GetCurrentSegment
()
{
return
impl_
->
GetCurrentSegment
();
}
void
OnlineStream
::
SetResult
(
const
OnlineTransducerDecoderResult
&
r
)
{
impl_
->
SetResult
(
r
);
}
...
...
sherpa-onnx/csrc/online-stream.h
查看文件 @
a0a747a
...
...
@@ -68,6 +68,8 @@ class OnlineStream {
// The returned reference is valid as long as this object is alive.
int32_t
&
GetNumProcessedFrames
();
int32_t
&
GetCurrentSegment
();
void
SetResult
(
const
OnlineTransducerDecoderResult
&
r
);
OnlineTransducerDecoderResult
&
GetResult
();
...
...
sherpa-onnx/csrc/online-websocket-server-impl.cc
查看文件 @
a0a747a
...
...
@@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() {
for
(
auto
c
:
c_vec
)
{
auto
result
=
recognizer_
->
GetResult
(
c
->
s
.
get
());
if
(
recognizer_
->
IsEndpoint
(
c
->
s
.
get
()))
{
recognizer_
->
Reset
(
c
->
s
.
get
());
}
asio
::
post
(
server_
->
GetConnectionContext
(),
[
this
,
hdl
=
c
->
hdl
,
str
=
result
.
AsJsonString
()]()
{
...
...
请
注册
或
登录
后发表评论