Fangjun Kuang
Committed by GitHub

Support Zipformer transducer ASR with whisper features. (#2321)

Adds support for Zipformer transducer ASR models that use Whisper-style
features by introducing a new feature flag, reading it from the model
metadata, and normalizing the features of each chunk.

- Introduce UseWhisperFeature in the model interface and Zipformer implementation
- Parse "feature" metadata to set the whisper flag and wire it into the recognizer
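- Add early returns after each InputFinished() call in the feature extractor (fbank, Whisper fbank, MFCC) so the trailing "unreachable code" log is skipped once an extractor has been handled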
@@ -131,10 +131,13 @@ class FeatureExtractor::Impl { @@ -131,10 +131,13 @@ class FeatureExtractor::Impl {
131 std::lock_guard<std::mutex> lock(mutex_); 131 std::lock_guard<std::mutex> lock(mutex_);
132 if (fbank_) { 132 if (fbank_) {
133 fbank_->InputFinished(); 133 fbank_->InputFinished();
  134 + return;
134 } else if (whisper_fbank_) { 135 } else if (whisper_fbank_) {
135 whisper_fbank_->InputFinished(); 136 whisper_fbank_->InputFinished();
  137 + return;
136 } else if (mfcc_) { 138 } else if (mfcc_) {
137 mfcc_->InputFinished(); 139 mfcc_->InputFinished();
  140 + return;
138 } 141 }
139 142
140 SHERPA_ONNX_LOGE("unreachable code"); 143 SHERPA_ONNX_LOGE("unreachable code");
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 16
17 #include "sherpa-onnx/csrc/file-utils.h" 17 #include "sherpa-onnx/csrc/file-utils.h"
18 #include "sherpa-onnx/csrc/macros.h" 18 #include "sherpa-onnx/csrc/macros.h"
  19 +#include "sherpa-onnx/csrc/offline-whisper-model.h"
19 #include "sherpa-onnx/csrc/online-lm.h" 20 #include "sherpa-onnx/csrc/online-lm.h"
20 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 21 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
21 #include "sherpa-onnx/csrc/online-recognizer.h" 22 #include "sherpa-onnx/csrc/online-recognizer.h"
@@ -133,6 +134,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -133,6 +134,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
133 config.decoding_method.c_str()); 134 config.decoding_method.c_str());
134 exit(-1); 135 exit(-1);
135 } 136 }
  137 +
  138 + if (model_->UseWhisperFeature()) {
  139 + config_.feat_config.is_whisper = true;
  140 + }
136 } 141 }
137 142
138 template <typename Manager> 143 template <typename Manager>
@@ -182,6 +187,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -182,6 +187,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
182 config.decoding_method.c_str()); 187 config.decoding_method.c_str());
183 exit(-1); 188 exit(-1);
184 } 189 }
  190 +
  191 + if (model_->UseWhisperFeature()) {
  192 + config_.feat_config.is_whisper = true;
  193 + }
185 } 194 }
186 195
187 std::unique_ptr<OnlineStream> CreateStream() const override { 196 std::unique_ptr<OnlineStream> CreateStream() const override {
@@ -292,6 +301,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -292,6 +301,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
292 std::vector<float> features = 301 std::vector<float> features =
293 ss[i]->GetFrames(num_processed_frames, chunk_size); 302 ss[i]->GetFrames(num_processed_frames, chunk_size);
294 303
  304 + if (config_.feat_config.is_whisper) {
  305 + OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_size,
  306 + feature_dim);
  307 + }
  308 +
295 // Question: should num_processed_frames include chunk_shift? 309 // Question: should num_processed_frames include chunk_shift?
296 ss[i]->GetNumProcessedFrames() += chunk_shift; 310 ss[i]->GetNumProcessedFrames() += chunk_shift;
297 311
@@ -132,6 +132,8 @@ class OnlineTransducerModel { @@ -132,6 +132,8 @@ class OnlineTransducerModel {
132 132
133 virtual int32_t SubsamplingFactor() const { return 4; } 133 virtual int32_t SubsamplingFactor() const { return 4; }
134 134
  135 + virtual bool UseWhisperFeature() const { return false; }
  136 +
135 virtual OrtAllocator *Allocator() = 0; 137 virtual OrtAllocator *Allocator() = 0;
136 138
137 Ort::Value BuildDecoderInput( 139 Ort::Value BuildDecoderInput(
@@ -120,6 +120,12 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, @@ -120,6 +120,12 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
120 SHERPA_ONNX_READ_META_DATA(T_, "T"); 120 SHERPA_ONNX_READ_META_DATA(T_, "T");
121 SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); 121 SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
122 122
  123 + std::string feature_type;
  124 + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", "");
  125 + if (feature_type == "whisper") {
  126 + use_whisper_feature_ = true;
  127 + }
  128 +
123 if (config_.debug) { 129 if (config_.debug) {
124 auto print = [](const std::vector<int32_t> &v, const char *name) { 130 auto print = [](const std::vector<int32_t> &v, const char *name) {
125 std::ostringstream os; 131 std::ostringstream os;
@@ -52,6 +52,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { @@ -52,6 +52,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
52 int32_t VocabSize() const override { return vocab_size_; } 52 int32_t VocabSize() const override { return vocab_size_; }
53 OrtAllocator *Allocator() override { return allocator_; } 53 OrtAllocator *Allocator() override { return allocator_; }
54 54
  55 + bool UseWhisperFeature() const override { return use_whisper_feature_; }
  56 +
55 private: 57 private:
56 void InitEncoder(void *model_data, size_t model_data_length); 58 void InitEncoder(void *model_data, size_t model_data_length);
57 void InitDecoder(void *model_data, size_t model_data_length); 59 void InitDecoder(void *model_data, size_t model_data_length);
@@ -103,6 +105,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { @@ -103,6 +105,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
103 int32_t context_size_ = 0; 105 int32_t context_size_ = 0;
104 int32_t vocab_size_ = 0; 106 int32_t vocab_size_ = 0;
105 int32_t feature_dim_ = 80; 107 int32_t feature_dim_ = 80;
  108 +
  109 + // True when the encoder metadata "feature" is "whisper"; e.g., for models from
  110 + // https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
  111 + bool use_whisper_feature_ = false;
106 }; 112 };
107 113
108 } // namespace sherpa_onnx 114 } // namespace sherpa_onnx