Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2025-06-27 00:15:11 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-06-27 00:15:11 +0800
Commit
54bf3732d9e7b2d16345f38d95e830b7308034be
54bf3732
1 parent
282211c0
Support zipformer CTC ASR with whisper features. (#2319)
隐藏空白字符变更
内嵌
并排对比
正在显示
8 个修改的文件
包含
184 行增加
和
37 行删除
sherpa-onnx/csrc/features.cc
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/online-ctc-model.h
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
sherpa-onnx/csrc/online-stream.cc
sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
sherpa-onnx/csrc/online-zipformer2-ctc-model.h
sherpa-onnx/csrc/sherpa-onnx-microphone.cc
sherpa-onnx/csrc/features.cc
查看文件 @
54bf373
...
...
@@ -60,6 +60,8 @@ class FeatureExtractor::Impl {
explicit
Impl
(
const
FeatureExtractorConfig
&
config
)
:
config_
(
config
)
{
if
(
config_
.
is_mfcc
)
{
InitMfcc
();
}
else
if
(
config_
.
is_whisper
)
{
InitWhisper
();
}
else
{
InitFbank
();
}
...
...
@@ -92,13 +94,9 @@ class FeatureExtractor::Impl {
std
::
vector
<
float
>
samples
;
resampler_
->
Resample
(
waveform
,
n
,
false
,
&
samples
);
if
(
fbank_
)
{
fbank_
->
AcceptWaveform
(
config_
.
sampling_rate
,
samples
.
data
(),
samples
.
size
());
}
else
{
mfcc_
->
AcceptWaveform
(
config_
.
sampling_rate
,
samples
.
data
(),
samples
.
size
());
}
AcceptWaveformWrapper
(
config_
.
sampling_rate
,
samples
.
data
(),
samples
.
size
());
return
;
}
...
...
@@ -119,61 +117,81 @@ class FeatureExtractor::Impl {
std
::
vector
<
float
>
samples
;
resampler_
->
Resample
(
waveform
,
n
,
false
,
&
samples
);
if
(
fbank_
)
{
fbank_
->
AcceptWaveform
(
config_
.
sampling_rate
,
samples
.
data
(),
samples
.
size
());
}
else
{
mfcc_
->
AcceptWaveform
(
config_
.
sampling_rate
,
samples
.
data
(),
samples
.
size
());
}
AcceptWaveformWrapper
(
config_
.
sampling_rate
,
samples
.
data
(),
samples
.
size
());
return
;
}
if
(
fbank_
)
{
fbank_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
}
else
{
mfcc_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
}
AcceptWaveformWrapper
(
sampling_rate
,
waveform
,
n
);
}
void
InputFinished
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
fbank_
->
InputFinished
();
if
(
fbank_
)
{
fbank_
->
InputFinished
();
}
else
if
(
whisper_fbank_
)
{
whisper_fbank_
->
InputFinished
();
}
else
if
(
mfcc_
)
{
mfcc_
->
InputFinished
();
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
int32_t
NumFramesReady
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
fbank_
->
NumFramesReady
();
if
(
fbank_
)
{
return
fbank_
->
NumFramesReady
();
}
else
if
(
whisper_fbank_
)
{
return
whisper_fbank_
->
NumFramesReady
();
}
else
if
(
mfcc_
)
{
return
mfcc_
->
NumFramesReady
();
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
return
-
1
;
}
bool
IsLastFrame
(
int32_t
frame
)
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
fbank_
->
IsLastFrame
(
frame
);
if
(
fbank_
)
{
return
fbank_
->
IsLastFrame
(
frame
);
}
else
if
(
whisper_fbank_
)
{
return
whisper_fbank_
->
IsLastFrame
(
frame
);
}
else
if
(
mfcc_
)
{
return
mfcc_
->
IsLastFrame
(
frame
);
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
return
false
;
}
std
::
vector
<
float
>
GetFrames
(
int32_t
frame_index
,
int32_t
n
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
frame_index
+
n
>
fbank_
->
NumFramesReady
())
{
SHERPA_ONNX_LOGE
(
"%d + %d > %d
\n
"
,
frame_index
,
n
,
fbank_
->
NumFramesReady
());
exit
(
-
1
);
if
(
frame_index
+
n
>
NumFramesReady
())
{
SHERPA_ONNX_LOGE
(
"%d + %d > %d
\n
"
,
frame_index
,
n
,
NumFramesReady
());
SHERPA_ONNX_EXIT
(
-
1
);
}
int32_t
discard_num
=
frame_index
-
last_frame_index_
;
if
(
discard_num
<
0
)
{
SHERPA_ONNX_LOGE
(
"last_frame_index_: %d, frame_index_: %d"
,
last_frame_index_
,
frame_index
);
exit
(
-
1
);
SHERPA_ONNX_EXIT
(
-
1
);
}
fbank_
->
Pop
(
discard_num
);
int32_t
feature_dim
=
fbank_
->
Dim
();
PopWrapper
(
discard_num
);
int32_t
feature_dim
=
FeatureDim
();
std
::
vector
<
float
>
features
(
feature_dim
*
n
);
float
*
p
=
features
.
data
();
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
const
float
*
f
=
fbank_
->
GetFrame
(
i
+
frame_index
);
const
float
*
f
=
GetFrameWrapper
(
i
+
frame_index
);
std
::
copy
(
f
,
f
+
feature_dim
,
p
);
p
+=
feature_dim
;
}
...
...
@@ -184,10 +202,65 @@ class FeatureExtractor::Impl {
}
int32_t
FeatureDim
()
const
{
return
mfcc_
?
mfcc_opts_
.
num_ceps
:
opts_
.
mel_opts
.
num_bins
;
if
(
fbank_
||
whisper_fbank_
)
{
return
opts_
.
mel_opts
.
num_bins
;
}
else
if
(
mfcc_
)
{
return
mfcc_opts_
.
num_ceps
;
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
return
-
1
;
}
private
:
void
AcceptWaveformWrapper
(
float
sampling_rate
,
const
float
*
waveform
,
int32_t
n
)
const
{
if
(
fbank_
)
{
fbank_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
return
;
}
else
if
(
whisper_fbank_
)
{
whisper_fbank_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
return
;
}
else
if
(
mfcc_
)
{
mfcc_
->
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
return
;
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
const
float
*
GetFrameWrapper
(
int32_t
frame_index
)
const
{
if
(
fbank_
)
{
return
fbank_
->
GetFrame
(
frame_index
);
}
else
if
(
whisper_fbank_
)
{
return
whisper_fbank_
->
GetFrame
(
frame_index
);
}
else
if
(
mfcc_
)
{
return
mfcc_
->
GetFrame
(
frame_index
);
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
return
nullptr
;
}
void
PopWrapper
(
int32_t
discard_num
)
const
{
if
(
fbank_
)
{
fbank_
->
Pop
(
discard_num
);
return
;
}
else
if
(
whisper_fbank_
)
{
whisper_fbank_
->
Pop
(
discard_num
);
return
;
}
else
if
(
mfcc_
)
{
mfcc_
->
Pop
(
discard_num
);
return
;
}
SHERPA_ONNX_LOGE
(
"unreachable code"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
void
InitFbank
()
{
opts_
.
frame_opts
.
dither
=
config_
.
dither
;
opts_
.
frame_opts
.
snip_edges
=
config_
.
snip_edges
;
...
...
@@ -208,6 +281,7 @@ class FeatureExtractor::Impl {
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
void
InitMfcc
()
{
mfcc_opts_
.
frame_opts
.
dither
=
config_
.
dither
;
mfcc_opts_
.
frame_opts
.
snip_edges
=
config_
.
snip_edges
;
...
...
@@ -232,9 +306,23 @@ class FeatureExtractor::Impl {
mfcc_
=
std
::
make_unique
<
knf
::
OnlineMfcc
>
(
mfcc_opts_
);
}
void
InitWhisper
()
{
config_
.
normalize_samples
=
true
;
opts_
.
frame_opts
.
samp_freq
=
16000
;
opts_
.
mel_opts
.
num_bins
=
config_
.
feature_dim
;
knf
::
WhisperFeatureOptions
whisper_opts
;
whisper_opts
.
frame_opts
=
opts_
.
frame_opts
;
whisper_opts
.
dim
=
config_
.
feature_dim
;
whisper_fbank_
=
std
::
make_unique
<
knf
::
OnlineWhisperFbank
>
(
whisper_opts
);
config_
.
sampling_rate
=
opts_
.
frame_opts
.
samp_freq
;
}
private
:
std
::
unique_ptr
<
knf
::
OnlineFbank
>
fbank_
;
std
::
unique_ptr
<
knf
::
OnlineMfcc
>
mfcc_
;
std
::
unique_ptr
<
knf
::
OnlineWhisperFbank
>
whisper_fbank_
;
knf
::
FbankOptions
opts_
;
knf
::
MfccOptions
mfcc_opts_
;
FeatureExtractorConfig
config_
;
...
...
sherpa-onnx/csrc/features.h
查看文件 @
54bf373
...
...
@@ -79,6 +79,8 @@ struct FeatureExtractorConfig {
bool
is_mfcc
=
false
;
bool
is_whisper
=
false
;
bool
round_to_power_of_two
=
true
;
std
::
string
ToString
()
const
;
...
...
sherpa-onnx/csrc/online-ctc-model.h
查看文件 @
54bf373
...
...
@@ -77,6 +77,8 @@ class OnlineCtcModel {
// Return true if the model supports batch size > 1
virtual
bool
SupportBatchProcessing
()
const
{
return
true
;
}
virtual
bool
UseWhisperFeature
()
const
{
return
false
;
}
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
查看文件 @
54bf373
...
...
@@ -15,6 +15,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
...
...
@@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_
.
feat_config
.
normalize_samples
=
false
;
}
if
(
model_
->
UseWhisperFeature
())
{
config_
.
feat_config
.
is_whisper
=
true
;
}
InitDecoder
();
}
...
...
@@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_
.
feat_config
.
normalize_samples
=
false
;
}
if
(
model_
->
UseWhisperFeature
())
{
config_
.
feat_config
.
is_whisper
=
true
;
}
InitDecoder
();
}
...
...
@@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
const
auto
num_processed_frames
=
ss
[
i
]
->
GetNumProcessedFrames
();
std
::
vector
<
float
>
features
=
ss
[
i
]
->
GetFrames
(
num_processed_frames
,
chunk_length
);
if
(
config_
.
feat_config
.
is_whisper
)
{
OfflineWhisperModel
::
NormalizeFeatures
(
features
.
data
(),
chunk_length
,
feat_dim
);
}
// Question: should num_processed_frames include chunk_shift?
ss
[
i
]
->
GetNumProcessedFrames
()
+=
chunk_shift
;
...
...
@@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
const
auto
num_processed_frames
=
s
->
GetNumProcessedFrames
();
std
::
vector
<
float
>
frames
=
s
->
GetFrames
(
num_processed_frames
,
chunk_length
);
if
(
config_
.
feat_config
.
is_whisper
)
{
OfflineWhisperModel
::
NormalizeFeatures
(
frames
.
data
(),
chunk_length
,
feat_dim
);
}
s
->
GetNumProcessedFrames
()
+=
chunk_shift
;
auto
memory_info
=
...
...
sherpa-onnx/csrc/online-stream.cc
查看文件 @
54bf373
...
...
@@ -19,34 +19,51 @@ class OnlineStream::Impl {
:
feat_extractor_
(
config
),
context_graph_
(
std
::
move
(
context_graph
))
{}
void
AcceptWaveform
(
int32_t
sampling_rate
,
const
float
*
waveform
,
int32_t
n
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
feat_extractor_
.
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
}
void
InputFinished
()
const
{
feat_extractor_
.
InputFinished
();
}
void
InputFinished
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
feat_extractor_
.
InputFinished
();
}
int32_t
NumFramesReady
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
feat_extractor_
.
NumFramesReady
()
-
start_frame_index_
;
}
bool
IsLastFrame
(
int32_t
frame
)
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
feat_extractor_
.
IsLastFrame
(
frame
);
}
std
::
vector
<
float
>
GetFrames
(
int32_t
frame_index
,
int32_t
n
)
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
feat_extractor_
.
GetFrames
(
frame_index
+
start_frame_index_
,
n
);
}
void
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
// we don't reset the feature extractor
start_frame_index_
+=
num_processed_frames_
;
num_processed_frames_
=
0
;
}
int32_t
&
GetNumProcessedFrames
()
{
return
num_processed_frames_
;
}
int32_t
&
GetNumProcessedFrames
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
num_processed_frames_
;
}
int32_t
GetNumFramesSinceStart
()
const
{
return
start_frame_index_
;
}
int32_t
GetNumFramesSinceStart
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
start_frame_index_
;
}
int32_t
&
GetCurrentSegment
()
{
return
segment_
;
}
int32_t
&
GetCurrentSegment
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
segment_
;
}
void
SetResult
(
const
OnlineTransducerDecoderResult
&
r
)
{
result_
=
r
;
}
...
...
@@ -125,6 +142,7 @@ class OnlineStream::Impl {
private
:
FeatureExtractor
feat_extractor_
;
mutable
std
::
mutex
mutex_
;
/// For contextual-biasing
ContextGraphPtr
context_graph_
;
int32_t
num_processed_frames_
=
0
;
// before subsampling
...
...
sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
查看文件 @
54bf373
...
...
@@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl {
int32_t
ChunkShift
()
const
{
return
decode_chunk_len_
;
}
bool
UseWhisperFeature
()
const
{
return
use_whisper_feature_
;
}
OrtAllocator
*
Allocator
()
{
return
allocator_
;
}
// Return a vector containing 3 tensors
...
...
@@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl {
SHERPA_ONNX_READ_META_DATA
(
T_
,
"T"
);
SHERPA_ONNX_READ_META_DATA
(
decode_chunk_len_
,
"decode_chunk_len"
);
std
::
string
feature_type
;
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT
(
feature_type
,
"feature"
,
""
);
if
(
feature_type
==
"whisper"
)
{
use_whisper_feature_
=
true
;
}
{
auto
shape
=
sess_
->
GetOutputTypeInfo
(
0
).
GetTensorTypeAndShapeInfo
().
GetShape
();
...
...
@@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl {
int32_t
T_
=
0
;
int32_t
decode_chunk_len_
=
0
;
int32_t
vocab_size_
=
0
;
// for models from
// https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
bool
use_whisper_feature_
=
false
;
};
OnlineZipformer2CtcModel
::
OnlineZipformer2CtcModel
(
...
...
@@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const {
return
impl_
->
ChunkShift
();
}
bool
OnlineZipformer2CtcModel
::
UseWhisperFeature
()
const
{
return
impl_
->
UseWhisperFeature
();
}
OrtAllocator
*
OnlineZipformer2CtcModel
::
Allocator
()
const
{
return
impl_
->
Allocator
();
}
...
...
sherpa-onnx/csrc/online-zipformer2-ctc-model.h
查看文件 @
54bf373
...
...
@@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel {
// before we process the next chunk.
int32_t
ChunkShift
()
const
override
;
bool
UseWhisperFeature
()
const
override
;
private
:
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
...
...
sherpa-onnx/csrc/sherpa-onnx-microphone.cc
查看文件 @
54bf373
...
...
@@ -130,7 +130,7 @@ for a list of pre-trained models to download.
}
if
(
!
mic
.
OpenDevice
(
device_index
,
mic_sample_rate
,
1
,
RecordCallback
,
nullptr
/* user_data */
))
{
s
.
get
()
))
{
fprintf
(
stderr
,
"portaudio error: %d
\n
"
,
device_index
);
exit
(
EXIT_FAILURE
);
}
...
...
请
注册
或
登录
后发表评论