Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-04-19 16:02:37 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-04-19 16:02:37 +0800
Commit
ad05f5266634aff7468a62f4657b7c1ebb7a7283
ad05f526
1 parent
4b5d2887
Add timestamps for streaming ASR. (#123)
隐藏空白字符变更
内嵌
并排对比
正在显示
11 个修改的文件
包含
170 行增加
和
19 行删除
CMakeLists.txt
cmake/json.cmake
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-recognizer.h
sherpa-onnx/csrc/online-transducer-decoder.cc
sherpa-onnx/csrc/online-transducer-decoder.h
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-websocket-server-impl.cc
sherpa-onnx/csrc/sherpa-onnx.cc
sherpa-onnx/jni/jni.cc
CMakeLists.txt
查看文件 @
ad05f52
...
...
@@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include
(
asio
)
endif
()
include
(
json
)
add_subdirectory
(
sherpa-onnx
)
if
(
SHERPA_ONNX_ENABLE_C_API
)
...
...
cmake/json.cmake
0 → 100644
查看文件 @
ad05f52
# Fetch the nlohmann/json 3.11.2 sources and add their include/ directory
# to the include path.
#
# A pre-downloaded archive found in one of the well-known local locations
# listed below is used in preference to the network URLs.
function(download_json)
  include(FetchContent)

  # Primary URL, mirror URL, and the checksum the archive must match.
  set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
  set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
  set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")

  # If you don't have access to the Internet,
  # please pre-download json into one of these locations.
  set(json_local_candidates
    $ENV{HOME}/Downloads/json-3.11.2.tar.gz
    ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
    ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
    /tmp/json-3.11.2.tar.gz
    /star-fj/fangjun/download/github/json-3.11.2.tar.gz
  )

  # The first existing local archive replaces the primary URL; the mirror
  # is dropped because no network access is needed in that case.
  foreach(candidate IN LISTS json_local_candidates)
    if(EXISTS ${candidate})
      file(TO_CMAKE_PATH "${candidate}" json_URL)
      unset(json_URL2)
      break()
    endif()
  endforeach()

  FetchContent_Declare(json
    URL
      ${json_URL}
      ${json_URL2}
    URL_HASH ${json_HASH}
  )

  # Populate only once per configure.
  FetchContent_GetProperties(json)
  if(NOT json_POPULATED)
    message(STATUS "Downloading json from ${json_URL}")
    FetchContent_Populate(json)
  endif()

  message(STATUS "json is downloaded to ${json_SOURCE_DIR}")

  # Use #include "nlohmann/json.hpp"
  include_directories(${json_SOURCE_DIR}/include)
endfunction()

download_json()
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
ad05f52
...
...
@@ -8,11 +8,13 @@
#include <assert.h>
#include <algorithm>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
...
...
@@ -22,16 +24,56 @@
namespace
sherpa_onnx
{
std
::
string
OnlineRecognizerResult
::
AsJsonString
()
const
{
using
json
=
nlohmann
::
json
;
json
j
;
j
[
"text"
]
=
text
;
j
[
"tokens"
]
=
tokens
;
j
[
"start_time"
]
=
start_time
;
#if 1
// This branch chooses number of decimal points to keep in
// the return json string
std
::
ostringstream
os
;
os
<<
"["
;
std
::
string
sep
=
""
;
for
(
auto
t
:
timestamps
)
{
os
<<
sep
<<
std
::
fixed
<<
std
::
setprecision
(
2
)
<<
t
;
sep
=
", "
;
}
os
<<
"]"
;
j
[
"timestamps"
]
=
os
.
str
();
#else
j
[
"timestamps"
]
=
timestamps
;
#endif
j
[
"segment"
]
=
segment
;
j
[
"is_final"
]
=
is_final
;
return
j
.
dump
();
}
static
OnlineRecognizerResult
Convert
(
const
OnlineTransducerDecoderResult
&
src
,
const
SymbolTable
&
sym_table
)
{
std
::
string
text
;
for
(
auto
t
:
src
.
tokens
)
{
text
+=
sym_table
[
t
];
const
SymbolTable
&
sym_table
,
int32_t
frame_shift_ms
,
int32_t
subsampling_factor
)
{
OnlineRecognizerResult
r
;
r
.
tokens
.
reserve
(
src
.
tokens
.
size
());
r
.
timestamps
.
reserve
(
src
.
tokens
.
size
());
for
(
auto
i
:
src
.
tokens
)
{
auto
sym
=
sym_table
[
i
];
r
.
text
.
append
(
sym
);
r
.
tokens
.
push_back
(
std
::
move
(
sym
));
}
float
frame_shift_s
=
frame_shift_ms
/
1000.
*
subsampling_factor
;
for
(
auto
t
:
src
.
timestamps
)
{
float
time
=
frame_shift_s
*
t
;
r
.
timestamps
.
push_back
(
time
);
}
OnlineRecognizerResult
ans
;
ans
.
text
=
std
::
move
(
text
);
return
ans
;
return
r
;
}
void
OnlineRecognizerConfig
::
Register
(
ParseOptions
*
po
)
{
...
...
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
OnlineTransducerDecoderResult
decoder_result
=
s
->
GetResult
();
decoder_
->
StripLeadingBlanks
(
&
decoder_result
);
return
Convert
(
decoder_result
,
sym_
);
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
4
;
return
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
);
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
{
...
...
sherpa-onnx/csrc/online-recognizer.h
查看文件 @
ad05f52
...
...
@@ -7,6 +7,7 @@
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
...
...
@@ -22,10 +23,45 @@
namespace
sherpa_onnx
{
struct
OnlineRecognizerResult
{
/// Recognition results.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
std
::
string
text
;
// TODO(fangjun): Add a method to return a json string
std
::
string
ToString
()
const
{
return
text
;
}
/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
std
::
vector
<
std
::
string
>
tokens
;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std
::
vector
<
float
>
timestamps
;
/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t
segment
=
0
;
/// Starting frame of this segment.
/// When an endpoint is detected, it will change
float
start_time
=
0
;
/// True if this is the last segment.
bool
is_final
=
false
;
/** Return a json string.
*
* The returned string contains:
* {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
* }
*/
std
::
string
AsJsonString
()
const
;
};
struct
OnlineRecognizerConfig
{
...
...
sherpa-onnx/csrc/online-transducer-decoder.cc
查看文件 @
ad05f52
...
...
@@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
hyps
=
other
.
hyps
;
frame_offset
=
other
.
frame_offset
;
timestamps
=
other
.
timestamps
;
return
*
this
;
}
...
...
@@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
decoder_out
=
std
::
move
(
other
.
decoder_out
);
hyps
=
std
::
move
(
other
.
hyps
);
frame_offset
=
other
.
frame_offset
;
timestamps
=
std
::
move
(
other
.
timestamps
);
return
*
this
;
}
...
...
sherpa-onnx/csrc/online-transducer-decoder.h
查看文件 @
ad05f52
...
...
@@ -13,12 +13,18 @@
namespace
sherpa_onnx
{
struct
OnlineTransducerDecoderResult
{
/// Number of frames after subsampling we have decoded so far
int32_t
frame_offset
=
0
;
/// The decoded token IDs so far
std
::
vector
<
int64_t
>
tokens
;
/// number of trailing blank frames decoded so far
int32_t
num_trailing_blanks
=
0
;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std
::
vector
<
int32_t
>
timestamps
;
// Cache decoder_out for endpointing
Ort
::
Value
decoder_out
;
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
ad05f52
...
...
@@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
bool
emitted
=
false
;
for
(
int32_t
i
=
0
;
i
<
batch_size
;
++
i
,
p_logit
+=
vocab_size
)
{
auto
&
r
=
(
*
result
)[
i
];
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
if
(
y
!=
0
)
{
emitted
=
true
;
(
*
result
)[
i
].
tokens
.
push_back
(
y
);
(
*
result
)[
i
].
num_trailing_blanks
=
0
;
r
.
tokens
.
push_back
(
y
);
r
.
timestamps
.
push_back
(
t
+
r
.
frame_offset
);
r
.
num_trailing_blanks
=
0
;
}
else
{
++
(
*
result
)[
i
]
.
num_trailing_blanks
;
++
r
.
num_trailing_blanks
;
}
}
if
(
emitted
)
{
...
...
@@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
UpdateCachedDecoderOut
(
model_
->
Allocator
(),
&
decoder_out
,
result
);
// Update frame_offset
for
(
auto
&
r
:
*
result
)
{
r
.
frame_offset
+=
num_frames
;
}
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
查看文件 @
ad05f52
...
...
@@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std
::
vector
<
int64_t
>
tokens
(
hyp
.
ys
.
begin
()
+
context_size
,
hyp
.
ys
.
end
());
r
->
tokens
=
std
::
move
(
tokens
);
r
->
timestamps
=
std
::
move
(
hyp
.
timestamps
);
r
->
num_trailing_blanks
=
hyp
.
num_trailing_blanks
;
}
...
...
@@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
for
(
int32_t
b
=
0
;
b
<
batch_size
;
++
b
)
{
int32_t
frame_offset
=
(
*
result
)[
b
].
frame_offset
;
int32_t
start
=
hyps_num_split
[
b
];
int32_t
end
=
hyps_num_split
[
b
+
1
];
LogSoftmax
(
p_logit
,
vocab_size
,
(
end
-
start
));
...
...
@@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Hypothesis
new_hyp
=
prev
[
hyp_index
];
if
(
new_token
!=
0
)
{
new_hyp
.
ys
.
push_back
(
new_token
);
new_hyp
.
timestamps
.
push_back
(
t
+
frame_offset
);
new_hyp
.
num_trailing_blanks
=
0
;
}
else
{
++
new_hyp
.
num_trailing_blanks
;
...
...
@@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
for
(
int32_t
b
=
0
;
b
!=
batch_size
;
++
b
)
{
auto
&
hyps
=
cur
[
b
];
auto
best_hyp
=
hyps
.
GetMostProbable
(
true
);
auto
&
r
=
(
*
result
)[
b
];
(
*
result
)[
b
].
hyps
=
std
::
move
(
hyps
);
(
*
result
)[
b
].
tokens
=
std
::
move
(
best_hyp
.
ys
);
(
*
result
)[
b
].
num_trailing_blanks
=
best_hyp
.
num_trailing_blanks
;
r
.
hyps
=
std
::
move
(
hyps
);
r
.
tokens
=
std
::
move
(
best_hyp
.
ys
);
r
.
num_trailing_blanks
=
best_hyp
.
num_trailing_blanks
;
r
.
frame_offset
+=
num_frames
;
}
}
...
...
sherpa-onnx/csrc/online-websocket-server-impl.cc
查看文件 @
ad05f52
...
...
@@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() {
auto
result
=
recognizer_
->
GetResult
(
c
->
s
.
get
());
asio
::
post
(
server_
->
GetConnectionContext
(),
[
this
,
hdl
=
c
->
hdl
,
str
=
result
.
To
String
()]()
{
[
this
,
hdl
=
c
->
hdl
,
str
=
result
.
AsJson
String
()]()
{
server_
->
Send
(
hdl
,
str
);
});
active_
.
erase
(
c
->
hdl
);
...
...
sherpa-onnx/csrc/sherpa-onnx.cc
查看文件 @
ad05f52
...
...
@@ -102,7 +102,7 @@ for a list of pre-trained models to download.
recognizer
.
DecodeStream
(
s
.
get
());
}
std
::
string
text
=
recognizer
.
GetResult
(
s
.
get
()).
text
;
std
::
string
text
=
recognizer
.
GetResult
(
s
.
get
()).
AsJsonString
()
;
fprintf
(
stderr
,
"Done!
\n
"
);
...
...
sherpa-onnx/jni/jni.cc
查看文件 @
ad05f52
...
...
@@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
sherpa_onnx
::
OnlineStream
*
s
=
reinterpret_cast
<
sherpa_onnx
::
OnlineStream
*>
(
s_ptr
);
sherpa_onnx
::
OnlineRecognizerResult
result
=
model
->
GetResult
(
s
);
return
env
->
NewStringUTF
(
result
.
ToString
()
.
c_str
());
return
env
->
NewStringUTF
(
result
.
text
.
c_str
());
}
SHERPA_ONNX_EXTERN_C
...
...
请
注册
或
登录
后发表评论