Fangjun Kuang
Committed by GitHub

Support Zipformer transducer ASR with whisper features. (#2321)

Adds support for Zipformer transducer ASR models that use Whisper-style 
features by introducing a new feature flag, parsing metadata, 
and integrating per-chunk normalization.

- Introduce UseWhisperFeature in the model interface and Zipformer implementation
- Parse "feature" metadata to set the whisper flag and wire it into the recognizer
- Update feature extraction logic to handle Whisper filterbanks with early returns
... ... @@ -131,10 +131,13 @@ class FeatureExtractor::Impl {
std::lock_guard<std::mutex> lock(mutex_);
if (fbank_) {
fbank_->InputFinished();
return;
} else if (whisper_fbank_) {
whisper_fbank_->InputFinished();
return;
} else if (mfcc_) {
mfcc_->InputFinished();
return;
}
SHERPA_ONNX_LOGE("unreachable code");
... ...
... ... @@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/online-lm.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
... ... @@ -133,6 +134,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
config.decoding_method.c_str());
exit(-1);
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
}
template <typename Manager>
... ... @@ -182,6 +187,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
config.decoding_method.c_str());
exit(-1);
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
}
std::unique_ptr<OnlineStream> CreateStream() const override {
... ... @@ -292,6 +301,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);
if (config_.feat_config.is_whisper) {
OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_size,
feature_dim);
}
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
... ...
... ... @@ -132,6 +132,8 @@ class OnlineTransducerModel {
virtual int32_t SubsamplingFactor() const { return 4; }
virtual bool UseWhisperFeature() const { return false; }
virtual OrtAllocator *Allocator() = 0;
Ort::Value BuildDecoderInput(
... ...
... ... @@ -120,6 +120,12 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
std::string feature_type;
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", "");
if (feature_type == "whisper") {
use_whisper_feature_ = true;
}
if (config_.debug) {
auto print = [](const std::vector<int32_t> &v, const char *name) {
std::ostringstream os;
... ...
... ... @@ -52,6 +52,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
int32_t VocabSize() const override { return vocab_size_; }
OrtAllocator *Allocator() override { return allocator_; }
bool UseWhisperFeature() const override { return use_whisper_feature_; }
private:
void InitEncoder(void *model_data, size_t model_data_length);
void InitDecoder(void *model_data, size_t model_data_length);
... ... @@ -103,6 +105,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
int32_t feature_dim_ = 80;
// for models from
// https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
bool use_whisper_feature_ = false;
};
} // namespace sherpa_onnx
... ...