Committed by
GitHub
Support Zipformer transducer ASR with whisper features. (#2321)
Adds support for Zipformer transducer ASR models that use Whisper-style features by introducing a new feature flag, parsing metadata, and integrating per-chunk normalization. - Introduce UseWhisperFeature in the model interface and Zipformer implementation - Parse "feature" metadata to set the whisper flag and wire it into the recognizer - Update feature extraction logic to handle Whisper filterbanks with early returns
正在显示
5 个修改的文件
包含
31 行增加
和
0 行删除
| @@ -131,10 +131,13 @@ class FeatureExtractor::Impl { | @@ -131,10 +131,13 @@ class FeatureExtractor::Impl { | ||
| 131 | std::lock_guard<std::mutex> lock(mutex_); | 131 | std::lock_guard<std::mutex> lock(mutex_); |
| 132 | if (fbank_) { | 132 | if (fbank_) { |
| 133 | fbank_->InputFinished(); | 133 | fbank_->InputFinished(); |
| 134 | + return; | ||
| 134 | } else if (whisper_fbank_) { | 135 | } else if (whisper_fbank_) { |
| 135 | whisper_fbank_->InputFinished(); | 136 | whisper_fbank_->InputFinished(); |
| 137 | + return; | ||
| 136 | } else if (mfcc_) { | 138 | } else if (mfcc_) { |
| 137 | mfcc_->InputFinished(); | 139 | mfcc_->InputFinished(); |
| 140 | + return; | ||
| 138 | } | 141 | } |
| 139 | 142 | ||
| 140 | SHERPA_ONNX_LOGE("unreachable code"); | 143 | SHERPA_ONNX_LOGE("unreachable code"); |
| @@ -16,6 +16,7 @@ | @@ -16,6 +16,7 @@ | ||
| 16 | 16 | ||
| 17 | #include "sherpa-onnx/csrc/file-utils.h" | 17 | #include "sherpa-onnx/csrc/file-utils.h" |
| 18 | #include "sherpa-onnx/csrc/macros.h" | 18 | #include "sherpa-onnx/csrc/macros.h" |
| 19 | +#include "sherpa-onnx/csrc/offline-whisper-model.h" | ||
| 19 | #include "sherpa-onnx/csrc/online-lm.h" | 20 | #include "sherpa-onnx/csrc/online-lm.h" |
| 20 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 21 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| 21 | #include "sherpa-onnx/csrc/online-recognizer.h" | 22 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| @@ -133,6 +134,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -133,6 +134,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 133 | config.decoding_method.c_str()); | 134 | config.decoding_method.c_str()); |
| 134 | exit(-1); | 135 | exit(-1); |
| 135 | } | 136 | } |
| 137 | + | ||
| 138 | + if (model_->UseWhisperFeature()) { | ||
| 139 | + config_.feat_config.is_whisper = true; | ||
| 140 | + } | ||
| 136 | } | 141 | } |
| 137 | 142 | ||
| 138 | template <typename Manager> | 143 | template <typename Manager> |
| @@ -182,6 +187,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -182,6 +187,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 182 | config.decoding_method.c_str()); | 187 | config.decoding_method.c_str()); |
| 183 | exit(-1); | 188 | exit(-1); |
| 184 | } | 189 | } |
| 190 | + | ||
| 191 | + if (model_->UseWhisperFeature()) { | ||
| 192 | + config_.feat_config.is_whisper = true; | ||
| 193 | + } | ||
| 185 | } | 194 | } |
| 186 | 195 | ||
| 187 | std::unique_ptr<OnlineStream> CreateStream() const override { | 196 | std::unique_ptr<OnlineStream> CreateStream() const override { |
| @@ -292,6 +301,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -292,6 +301,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 292 | std::vector<float> features = | 301 | std::vector<float> features = |
| 293 | ss[i]->GetFrames(num_processed_frames, chunk_size); | 302 | ss[i]->GetFrames(num_processed_frames, chunk_size); |
| 294 | 303 | ||
| 304 | + if (config_.feat_config.is_whisper) { | ||
| 305 | + OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_size, | ||
| 306 | + feature_dim); | ||
| 307 | + } | ||
| 308 | + | ||
| 295 | // Question: should num_processed_frames include chunk_shift? | 309 | // Question: should num_processed_frames include chunk_shift? |
| 296 | ss[i]->GetNumProcessedFrames() += chunk_shift; | 310 | ss[i]->GetNumProcessedFrames() += chunk_shift; |
| 297 | 311 |
| @@ -132,6 +132,8 @@ class OnlineTransducerModel { | @@ -132,6 +132,8 @@ class OnlineTransducerModel { | ||
| 132 | 132 | ||
| 133 | virtual int32_t SubsamplingFactor() const { return 4; } | 133 | virtual int32_t SubsamplingFactor() const { return 4; } |
| 134 | 134 | ||
| 135 | + virtual bool UseWhisperFeature() const { return false; } | ||
| 136 | + | ||
| 135 | virtual OrtAllocator *Allocator() = 0; | 137 | virtual OrtAllocator *Allocator() = 0; |
| 136 | 138 | ||
| 137 | Ort::Value BuildDecoderInput( | 139 | Ort::Value BuildDecoderInput( |
| @@ -120,6 +120,12 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, | @@ -120,6 +120,12 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, | ||
| 120 | SHERPA_ONNX_READ_META_DATA(T_, "T"); | 120 | SHERPA_ONNX_READ_META_DATA(T_, "T"); |
| 121 | SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); | 121 | SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); |
| 122 | 122 | ||
| 123 | + std::string feature_type; | ||
| 124 | + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", ""); | ||
| 125 | + if (feature_type == "whisper") { | ||
| 126 | + use_whisper_feature_ = true; | ||
| 127 | + } | ||
| 128 | + | ||
| 123 | if (config_.debug) { | 129 | if (config_.debug) { |
| 124 | auto print = [](const std::vector<int32_t> &v, const char *name) { | 130 | auto print = [](const std::vector<int32_t> &v, const char *name) { |
| 125 | std::ostringstream os; | 131 | std::ostringstream os; |
| @@ -52,6 +52,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | @@ -52,6 +52,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | ||
| 52 | int32_t VocabSize() const override { return vocab_size_; } | 52 | int32_t VocabSize() const override { return vocab_size_; } |
| 53 | OrtAllocator *Allocator() override { return allocator_; } | 53 | OrtAllocator *Allocator() override { return allocator_; } |
| 54 | 54 | ||
| 55 | + bool UseWhisperFeature() const override { return use_whisper_feature_; } | ||
| 56 | + | ||
| 55 | private: | 57 | private: |
| 56 | void InitEncoder(void *model_data, size_t model_data_length); | 58 | void InitEncoder(void *model_data, size_t model_data_length); |
| 57 | void InitDecoder(void *model_data, size_t model_data_length); | 59 | void InitDecoder(void *model_data, size_t model_data_length); |
| @@ -103,6 +105,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | @@ -103,6 +105,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | ||
| 103 | int32_t context_size_ = 0; | 105 | int32_t context_size_ = 0; |
| 104 | int32_t vocab_size_ = 0; | 106 | int32_t vocab_size_ = 0; |
| 105 | int32_t feature_dim_ = 80; | 107 | int32_t feature_dim_ = 80; |
| 108 | + | ||
| 109 | + // for models from | ||
| 110 | + // https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head | ||
| 111 | + bool use_whisper_feature_ = false; | ||
| 106 | }; | 112 | }; |
| 107 | 113 | ||
| 108 | } // namespace sherpa_onnx | 114 | } // namespace sherpa_onnx |
-
请 注册 或 登录 后发表评论