Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
ivan provalov
2024-07-13 07:30:47 -0700
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-07-13 22:30:47 +0800
Commit
de04b3b9bfc6d48a8ac340e00083d9fd5411b81e
de04b3b9
1 parent
ab71c397
Allow modify model config at decode time for ASR (#1124)
隐藏空白字符变更
内嵌
并排对比
正在显示
15 个修改的文件
包含
121 行增加
和
13 行删除
sherpa-onnx/c-api/c-api.cc
sherpa-onnx/c-api/c-api.h
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
sherpa-onnx/csrc/offline-recognizer-impl.cc
sherpa-onnx/csrc/offline-recognizer-impl.h
sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
sherpa-onnx/csrc/offline-recognizer.cc
sherpa-onnx/csrc/offline-recognizer.h
sherpa-onnx/csrc/offline-stream.h
sherpa-onnx/csrc/offline-whisper-decoder.h
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
sherpa-onnx/c-api/c-api.cc
查看文件 @
de04b3b
...
...
@@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream {
:
impl
(
std
::
move
(
p
))
{}
};
static
sherpa_onnx
::
OfflineRecognizerConfig
convertConfig
(
const
SherpaOnnxOfflineRecognizerConfig
*
config
);
/// Create an offline recognizer from a C API config.
///
/// @param config Config for the recognizer. It is converted with
///               convertConfig() and then checked with Validate().
/// @return A newly allocated recognizer, or nullptr if `config` is a null
///         pointer or the converted config does not pass Validate().
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
    const SherpaOnnxOfflineRecognizerConfig *config) {
  if (config == nullptr) {
    SHERPA_ONNX_LOGE("config is a null pointer");
    return nullptr;
  }

  sherpa_onnx::OfflineRecognizerConfig recognizer_config =
      convertConfig(config);

  if (!recognizer_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config");
    return nullptr;
  }

  // Build the implementation before allocating the wrapper: if the
  // constructor throws, no raw allocation is left behind.
  auto impl =
      std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);

  SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;
  recognizer->impl = std::move(impl);

  return recognizer;
}
sherpa_onnx
::
OfflineRecognizerConfig
convertConfig
(
const
SherpaOnnxOfflineRecognizerConfig
*
config
)
{
sherpa_onnx
::
OfflineRecognizerConfig
recognizer_config
;
recognizer_config
.
feat_config
.
sampling_rate
=
...
...
@@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
SHERPA_ONNX_LOGE
(
"%s"
,
recognizer_config
.
ToString
().
c_str
());
}
if
(
!
recognizer_config
.
Validate
())
{
SHERPA_ONNX_LOGE
(
"Errors in config"
);
return
nullptr
;
}
SherpaOnnxOfflineRecognizer
*
recognizer
=
new
SherpaOnnxOfflineRecognizer
;
recognizer
->
impl
=
std
::
make_unique
<
sherpa_onnx
::
OfflineRecognizer
>
(
recognizer_config
);
return
recognizer_config
;
}
return
recognizer
;
void
SherpaOnnxOfflineRecognizerSetConfig
(
const
SherpaOnnxOfflineRecognizer
*
recognizer
,
const
SherpaOnnxOfflineRecognizerConfig
*
config
){
sherpa_onnx
::
OfflineRecognizerConfig
recognizer_config
=
convertConfig
(
config
);
recognizer
->
impl
->
SetConfig
(
recognizer_config
);
}
void
DestroyOfflineRecognizer
(
SherpaOnnxOfflineRecognizer
*
recognizer
)
{
...
...
@@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
pText
[
text
.
size
()]
=
0
;
r
->
text
=
pText
;
//lang
const
auto
&
lang
=
result
.
lang
;
char
*
c_lang
=
new
char
[
lang
.
size
()
+
1
];
std
::
copy
(
lang
.
begin
(),
lang
.
end
(),
c_lang
);
c_lang
[
lang
.
size
()]
=
'\0'
;
r
->
lang
=
c_lang
;
// copy json
std
::
string
json
=
result
.
AsJsonString
();
char
*
pJson
=
new
char
[
json
.
size
()
+
1
];
...
...
@@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult(
delete
[]
r
->
tokens
;
delete
[]
r
->
tokens_arr
;
delete
[]
r
->
json
;
delete
[]
r
->
lang
;
delete
r
;
}
}
...
...
sherpa-onnx/c-api/c-api.h
查看文件 @
de04b3b
...
...
@@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream;
SHERPA_ONNX_API
SherpaOnnxOfflineRecognizer
*
CreateOfflineRecognizer
(
const
SherpaOnnxOfflineRecognizerConfig
*
config
);
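/// Update the config of an existing offline recognizer.
///
/// @param recognizer The recognizer whose config is updated.
/// @param config The new config for the recognizer.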
SHERPA_ONNX_API
void
SherpaOnnxOfflineRecognizerSetConfig
(
const
SherpaOnnxOfflineRecognizer
*
recognizer
,
const
SherpaOnnxOfflineRecognizerConfig
*
config
);
/// Free a pointer returned by CreateOfflineRecognizer()
///
/// @param p A pointer returned by CreateOfflineRecognizer()
...
...
@@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
SHERPA_ONNX_API
typedef
struct
SherpaOnnxOfflineRecognizerResult
{
const
char
*
text
;
// Pointer to continuous memory which holds timestamps
// Pointer to continuous memory which holds timestamps
//
// It is NULL if the model does not support timestamps
float
*
timestamps
;
...
...
@@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
* }
*/
const
char
*
json
;
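// The recognized language as a NUL-terminated string. It is never NULL;
// it is an empty string when the result carries no language.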
const
char
*
lang
;
}
SherpaOnnxOfflineRecognizerResult
;
/// Get the result of the offline stream.
...
...
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
查看文件 @
de04b3b
...
...
@@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}
// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override { return config_; }
private
:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
...
...
sherpa-onnx/csrc/offline-recognizer-impl.cc
查看文件 @
de04b3b
...
...
@@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
return
text
;
}
// Default implementation: overwrite the stored config with `config`.
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
  config_ = config;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-recognizer-impl.h
查看文件 @
de04b3b
...
...
@@ -48,6 +48,10 @@ class OfflineRecognizerImpl {
virtual
void
DecodeStreams
(
OfflineStream
**
ss
,
int32_t
n
)
const
=
0
;
virtual
void
SetConfig
(
const
OfflineRecognizerConfig
&
config
);
virtual
OfflineRecognizerConfig
GetConfig
()
const
=
0
;
std
::
string
ApplyInverseTextNormalization
(
std
::
string
text
)
const
;
private
:
...
...
sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
查看文件 @
de04b3b
...
...
@@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
}
}
// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override { return config_; }
private
:
std
::
vector
<
float
>
ApplyLFR
(
const
std
::
vector
<
float
>
&
in
)
const
{
int32_t
lfr_window_size
=
model_
->
LfrWindowSize
();
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
查看文件 @
de04b3b
...
...
@@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override { return config_; }
void
InitHotwords
()
{
// each line in hotwords_file contains space-separated words
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
查看文件 @
de04b3b
...
...
@@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
}
}
// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override { return config_; }
private
:
void
PostInit
()
{
config_
.
feat_config
.
nemo_normalize_type
=
...
...
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
查看文件 @
de04b3b
...
...
@@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
}
r
.
text
=
text
;
r
.
lang
=
src
.
lang
;
return
r
;
}
...
...
@@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}
// Update the config of this recognizer.
//
// Only the whisper model config is taken from `config`; every other field
// of `config` is ignored.
void SetConfig(const OfflineRecognizerConfig &config) override {
  const auto &whisper_config = config.model_config.whisper;
  config_.model_config.whisper = whisper_config;
}
// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override { return config_; }
private
:
void
DecodeStream
(
OfflineStream
*
s
)
const
{
decoder_
->
SetConfig
(
config_
.
model_config
.
whisper
);
int32_t
max_num_frames
=
3000
;
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
...
...
sherpa-onnx/csrc/offline-recognizer.cc
查看文件 @
de04b3b
...
...
@@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
impl_
->
DecodeStreams
(
ss
,
n
);
}
// Forward the new config to the underlying recognizer implementation.
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
  impl_->SetConfig(config);
}
// Return the config reported by the underlying recognizer implementation.
OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
  return impl_->GetConfig();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-recognizer.h
查看文件 @
de04b3b
...
...
@@ -119,6 +119,15 @@ class OfflineRecognizer {
*/
void
DecodeStreams
(
OfflineStream
**
ss
,
int32_t
n
)
const
;
/** Update the config of the recognizer at decode time.
 *
 * Onnxruntime Session objects are not affected by this method.
 * The exact behavior is defined by each recognizer impl. For instance,
 * the whisper recognizer reads only the whisper model config from
 * `config` and ignores all remaining fields.
 */
void
SetConfig
(
const
OfflineRecognizerConfig
&
config
);
OfflineRecognizerConfig
GetConfig
()
const
;
private
:
std
::
unique_ptr
<
OfflineRecognizerImpl
>
impl_
;
};
...
...
sherpa-onnx/csrc/offline-stream.h
查看文件 @
de04b3b
...
...
@@ -26,7 +26,9 @@ struct OfflineRecognitionResult {
// For instance, for BPE-based models it consists of a list of BPE tokens.
std
::
vector
<
std
::
string
>
tokens
;
/// timestamps.size() == tokens.size()
std
::
string
lang
;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std
::
vector
<
float
>
timestamps
;
...
...
sherpa-onnx/csrc/offline-whisper-decoder.h
查看文件 @
de04b3b
...
...
@@ -6,14 +6,17 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#include <vector>
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
namespace
sherpa_onnx
{
/// Output of a whisper decoder.
struct OfflineWhisperDecoderResult {
  /// IDs of the decoded tokens.
  std::vector<int32_t> tokens;

  /// Language of the decoded result; empty if it could not be determined.
  std::string lang;
};
class
OfflineWhisperDecoder
{
...
...
@@ -31,6 +34,9 @@ class OfflineWhisperDecoder {
*/
virtual
std
::
vector
<
OfflineWhisperDecoderResult
>
Decode
(
Ort
::
Value
n_layer_cross_k
,
Ort
::
Value
n_layer_cross_v
)
=
0
;
virtual
void
SetConfig
(
const
OfflineWhisperModelConfig
&
config
)
=
0
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
查看文件 @
de04b3b
...
...
@@ -12,6 +12,10 @@
namespace
sherpa_onnx
{
// Store a copy of `config` in this decoder.
void OfflineWhisperGreedySearchDecoder::SetConfig(
    const OfflineWhisperModelConfig &config) {
  config_ = config;
}
std
::
vector
<
OfflineWhisperDecoderResult
>
OfflineWhisperGreedySearchDecoder
::
Decode
(
Ort
::
Value
cross_k
,
Ort
::
Value
cross_v
)
{
...
...
@@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
std
::
vector
<
OfflineWhisperDecoderResult
>
ans
(
1
);
const
auto
&
id2lang
=
model_
->
GetID2Lang
();
if
(
id2lang
.
count
(
initial_tokens
[
1
]))
{
ans
[
0
].
lang
=
id2lang
.
at
(
initial_tokens
[
1
]);
}
else
{
ans
[
0
].
lang
=
""
;
}
ans
[
0
].
tokens
=
std
::
move
(
predicted_tokens
);
return
ans
;
...
...
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
查看文件 @
de04b3b
...
...
@@ -8,7 +8,6 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
namespace
sherpa_onnx
{
...
...
@@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
std
::
vector
<
OfflineWhisperDecoderResult
>
Decode
(
Ort
::
Value
cross_k
,
Ort
::
Value
cross_v
)
override
;
void
SetConfig
(
const
OfflineWhisperModelConfig
&
config
)
override
;
private
:
OfflineWhisperModelConfig
config_
;
OfflineWhisperModel
*
model_
;
// not owned
...
...
请
注册
或
登录
后发表评论