Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
xinhecuican
2024-03-11 10:23:46 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-03-11 10:23:46 +0800
Commit
f43139e803d5084504d5c6368aa7dc01a48ce917
f43139e8
1 parent
1777a5dd
c++ api for keyword spotter (#642)
隐藏空白字符变更
内嵌
并排对比
正在显示
2 个修改的文件
包含
301 行增加
和
0 行删除
sherpa-onnx/c-api/c-api.cc
sherpa-onnx/c-api/c-api.h
sherpa-onnx/c-api/c-api.cc
查看文件 @
f43139e
...
...
@@ -17,6 +17,7 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
struct
SherpaOnnxOnlineRecognizer
{
std
::
unique_ptr
<
sherpa_onnx
::
OnlineRecognizer
>
impl
;
...
...
@@ -411,6 +412,189 @@ void DestroyOfflineRecognizerResult(
}
// ============================================================
// For Keyword Spot
// ============================================================
struct
SherpaOnnxKeywordSpotter
{
std
::
unique_ptr
<
sherpa_onnx
::
KeywordSpotter
>
impl
;
};
SherpaOnnxKeywordSpotter
*
CreateKeywordSpotter
(
const
SherpaOnnxKeywordSpotterConfig
*
config
)
{
sherpa_onnx
::
KeywordSpotterConfig
spotter_config
;
spotter_config
.
feat_config
.
sampling_rate
=
SHERPA_ONNX_OR
(
config
->
feat_config
.
sample_rate
,
16000
);
spotter_config
.
feat_config
.
feature_dim
=
SHERPA_ONNX_OR
(
config
->
feat_config
.
feature_dim
,
80
);
spotter_config
.
model_config
.
transducer
.
encoder
=
SHERPA_ONNX_OR
(
config
->
model_config
.
transducer
.
encoder
,
""
);
spotter_config
.
model_config
.
transducer
.
decoder
=
SHERPA_ONNX_OR
(
config
->
model_config
.
transducer
.
decoder
,
""
);
spotter_config
.
model_config
.
transducer
.
joiner
=
SHERPA_ONNX_OR
(
config
->
model_config
.
transducer
.
joiner
,
""
);
spotter_config
.
model_config
.
paraformer
.
encoder
=
SHERPA_ONNX_OR
(
config
->
model_config
.
paraformer
.
encoder
,
""
);
spotter_config
.
model_config
.
paraformer
.
decoder
=
SHERPA_ONNX_OR
(
config
->
model_config
.
paraformer
.
decoder
,
""
);
spotter_config
.
model_config
.
zipformer2_ctc
.
model
=
SHERPA_ONNX_OR
(
config
->
model_config
.
zipformer2_ctc
.
model
,
""
);
spotter_config
.
model_config
.
tokens
=
SHERPA_ONNX_OR
(
config
->
model_config
.
tokens
,
""
);
spotter_config
.
model_config
.
num_threads
=
SHERPA_ONNX_OR
(
config
->
model_config
.
num_threads
,
1
);
spotter_config
.
model_config
.
provider
=
SHERPA_ONNX_OR
(
config
->
model_config
.
provider
,
"cpu"
);
spotter_config
.
model_config
.
model_type
=
SHERPA_ONNX_OR
(
config
->
model_config
.
model_type
,
""
);
spotter_config
.
model_config
.
debug
=
SHERPA_ONNX_OR
(
config
->
model_config
.
debug
,
0
);
spotter_config
.
max_active_paths
=
SHERPA_ONNX_OR
(
config
->
max_active_paths
,
4
);
spotter_config
.
num_trailing_blanks
=
SHERPA_ONNX_OR
(
config
->
num_trailing_blanks
,
1
);
spotter_config
.
keywords_score
=
SHERPA_ONNX_OR
(
config
->
keywords_score
,
1.0
);
spotter_config
.
keywords_threshold
=
SHERPA_ONNX_OR
(
config
->
keywords_threshold
,
0.25
);
spotter_config
.
keywords_file
=
SHERPA_ONNX_OR
(
config
->
keywords_file
,
""
);
if
(
config
->
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"%s
\n
"
,
spotter_config
.
ToString
().
c_str
());
}
if
(
!
spotter_config
.
Validate
())
{
SHERPA_ONNX_LOGE
(
"Errors in config!"
);
return
nullptr
;
}
SherpaOnnxKeywordSpotter
*
spotter
=
new
SherpaOnnxKeywordSpotter
;
spotter
->
impl
=
std
::
make_unique
<
sherpa_onnx
::
KeywordSpotter
>
(
spotter_config
);
return
spotter
;
}
void
DestroyKeywordSpotter
(
SherpaOnnxKeywordSpotter
*
spotter
)
{
delete
spotter
;
}
SherpaOnnxOnlineStream
*
CreateKeywordStream
(
const
SherpaOnnxKeywordSpotter
*
spotter
)
{
SherpaOnnxOnlineStream
*
stream
=
new
SherpaOnnxOnlineStream
(
spotter
->
impl
->
CreateStream
());
return
stream
;
}
int32_t
IsKeywordStreamReady
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
*
stream
)
{
return
spotter
->
impl
->
IsReady
(
stream
->
impl
.
get
());
}
void
DecodeKeywordStream
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
*
stream
)
{
return
spotter
->
impl
->
DecodeStream
(
stream
->
impl
.
get
());
}
void
DecodeMultipleKeywordStreams
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
**
streams
,
int32_t
n
)
{
std
::
vector
<
sherpa_onnx
::
OnlineStream
*>
ss
(
n
);
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
ss
[
i
]
=
streams
[
i
]
->
impl
.
get
();
}
spotter
->
impl
->
DecodeStreams
(
ss
.
data
(),
n
);
}
const
SherpaOnnxKeywordResult
*
GetKeywordResult
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
*
stream
)
{
const
sherpa_onnx
::
KeywordResult
&
result
=
spotter
->
impl
->
GetResult
(
stream
->
impl
.
get
());
const
auto
&
keyword
=
result
.
keyword
;
auto
r
=
new
SherpaOnnxKeywordResult
;
memset
(
r
,
0
,
sizeof
(
SherpaOnnxKeywordResult
));
r
->
start_time
=
result
.
start_time
;
// copy keyword
r
->
keyword
=
new
char
[
keyword
.
size
()
+
1
];
std
::
copy
(
keyword
.
begin
(),
keyword
.
end
(),
const_cast
<
char
*>
(
r
->
keyword
));
const_cast
<
char
*>
(
r
->
keyword
)[
keyword
.
size
()]
=
0
;
// copy json
const
auto
&
json
=
result
.
AsJsonString
();
r
->
json
=
new
char
[
json
.
size
()
+
1
];
std
::
copy
(
json
.
begin
(),
json
.
end
(),
const_cast
<
char
*>
(
r
->
json
));
const_cast
<
char
*>
(
r
->
json
)[
json
.
size
()]
=
0
;
// copy tokens
auto
count
=
result
.
tokens
.
size
();
if
(
count
>
0
)
{
size_t
total_length
=
0
;
for
(
const
auto
&
token
:
result
.
tokens
)
{
// +1 for the null character at the end of each token
total_length
+=
token
.
size
()
+
1
;
}
r
->
count
=
count
;
// Each word ends with nullptr
r
->
tokens
=
new
char
[
total_length
];
memset
(
reinterpret_cast
<
void
*>
(
const_cast
<
char
*>
(
r
->
tokens
)),
0
,
total_length
);
char
**
tokens_temp
=
new
char
*
[
r
->
count
];
int32_t
pos
=
0
;
for
(
int32_t
i
=
0
;
i
<
r
->
count
;
++
i
)
{
tokens_temp
[
i
]
=
const_cast
<
char
*>
(
r
->
tokens
)
+
pos
;
memcpy
(
reinterpret_cast
<
void
*>
(
const_cast
<
char
*>
(
r
->
tokens
+
pos
)),
result
.
tokens
[
i
].
c_str
(),
result
.
tokens
[
i
].
size
());
// +1 to move past the null character
pos
+=
result
.
tokens
[
i
].
size
()
+
1
;
}
r
->
tokens_arr
=
tokens_temp
;
if
(
!
result
.
timestamps
.
empty
())
{
r
->
timestamps
=
new
float
[
result
.
timestamps
.
size
()];
std
::
copy
(
result
.
timestamps
.
begin
(),
result
.
timestamps
.
end
(),
r
->
timestamps
);
}
else
{
r
->
timestamps
=
nullptr
;
}
}
else
{
r
->
count
=
0
;
r
->
timestamps
=
nullptr
;
r
->
tokens
=
nullptr
;
r
->
tokens_arr
=
nullptr
;
}
return
r
;
}
void
DestroyKeywordResult
(
const
SherpaOnnxKeywordResult
*
r
)
{
if
(
r
)
{
delete
[]
r
->
keyword
;
delete
[]
r
->
json
;
delete
[]
r
->
tokens
;
delete
[]
r
->
tokens_arr
;
delete
[]
r
->
timestamps
;
delete
r
;
}
}
// ============================================================
// For VAD
// ============================================================
//
...
...
sherpa-onnx/c-api/c-api.h
查看文件 @
f43139e
...
...
@@ -474,6 +474,123 @@ SHERPA_ONNX_API void DestroyOfflineRecognizerResult(
const
SherpaOnnxOfflineRecognizerResult
*
r
);
// ============================================================
// For Keyword Spot
// ============================================================
SHERPA_ONNX_API
typedef
struct
SherpaOnnxKeywordResult
{
/// The triggered keyword.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
const
char
*
keyword
;
/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
const
char
*
tokens
;
const
char
*
const
*
tokens_arr
;
int32_t
count
;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
float
*
timestamps
;
/// Starting time of this segment.
/// When an endpoint is detected, it will change
float
start_time
;
/** Return a json string.
*
* The returned string contains:
* {
* "keyword": "The triggered keyword",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "start_time": x,
* }
*/
const
char
*
json
;
}
SherpaOnnxKeywordResult
;
SHERPA_ONNX_API
typedef
struct
SherpaOnnxKeywordSpotterConfig
{
SherpaOnnxFeatureConfig
feat_config
;
SherpaOnnxOnlineModelConfig
model_config
;
int32_t
max_active_paths
;
int32_t
num_trailing_blanks
;
float
keywords_score
;
float
keywords_threshold
;
const
char
*
keywords_file
;
}
SherpaOnnxKeywordSpotterConfig
;
SHERPA_ONNX_API
typedef
struct
SherpaOnnxKeywordSpotter
SherpaOnnxKeywordSpotter
;
/// @param config Config for the keyword spotter.
/// @return Return a pointer to the spotter. The user has to invoke
/// DestroyKeywordSpotter() to free it to avoid memory leak.
SHERPA_ONNX_API
SherpaOnnxKeywordSpotter
*
CreateKeywordSpotter
(
const
SherpaOnnxKeywordSpotterConfig
*
config
);
/// Free a pointer returned by CreateKeywordSpotter()
///
/// @param p A pointer returned by CreateKeywordSpotter()
SHERPA_ONNX_API
void
DestroyKeywordSpotter
(
SherpaOnnxKeywordSpotter
*
spotter
);
/// Create an online stream for accepting wave samples.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @return Return a pointer to an OnlineStream. The user has to invoke
/// DestroyOnlineStream() to free it to avoid memory leak.
SHERPA_ONNX_API
SherpaOnnxOnlineStream
*
CreateKeywordStream
(
const
SherpaOnnxKeywordSpotter
*
spotter
);
/// Return 1 if there are enough number of feature frames for decoding.
/// Return 0 otherwise.
///
/// @param spotter A pointer returned by CreateKeywordSpotter
/// @param stream A pointer returned by CreateKeywordStream
SHERPA_ONNX_API
int32_t
IsKeywordStreamReady
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
*
stream
);
/// Call this function to run the neural network model and decoding.
//
/// Precondition for this function: IsKeywordStreamReady() MUST return 1.
SHERPA_ONNX_API
void
DecodeKeywordStream
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
*
stream
);
/// This function is similar to DecodeKeywordStream(). It decodes multiple
/// OnlineStream in parallel.
///
/// Caution: The caller has to ensure each OnlineStream is ready, i.e.,
/// IsKeywordStreamReady() for that stream should return 1.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @param streams A pointer array containing pointers returned by
/// CreateKeywordStream()
/// @param n Number of elements in the given streams array.
SHERPA_ONNX_API
void
DecodeMultipleKeywordStreams
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
**
streams
,
int32_t
n
);
/// Get the decoding results so far for an OnlineStream.
///
/// @param recognizer A pointer returned by CreateKeywordSpotter().
/// @param stream A pointer returned by CreateKeywordStream().
/// @return A pointer containing the result. The user has to invoke
/// DestroyKeywordResult() to free the returned pointer to
/// avoid memory leak.
SHERPA_ONNX_API
const
SherpaOnnxKeywordResult
*
GetKeywordResult
(
SherpaOnnxKeywordSpotter
*
spotter
,
SherpaOnnxOnlineStream
*
stream
);
/// Destroy the pointer returned by GetKeywordResult().
///
/// @param r A pointer returned by GetKeywordResult()
SHERPA_ONNX_API
void
DestroyKeywordResult
(
const
SherpaOnnxKeywordResult
*
r
);
// ============================================================
// For VAD
// ============================================================
...
...
请
注册
或
登录
后发表评论