Fangjun Kuang
Committed by GitHub

Support zipformer CTC ASR with whisper features. (#2319)

... ... @@ -60,6 +60,8 @@ class FeatureExtractor::Impl {
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
if (config_.is_mfcc) {
InitMfcc();
} else if (config_.is_whisper) {
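// Whisper-style log-mel filterbank features; selected when the model's
// metadata requests them (see InitWhisper() below).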
InitWhisper();
} else {
InitFbank();
}
... ... @@ -92,13 +94,9 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
- if (fbank_) {
-   fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
-                          samples.size());
- } else {
-   mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
-                         samples.size());
- }
+ AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
+                       samples.size());
return;
}
... ... @@ -119,61 +117,81 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
- if (fbank_) {
-   fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
-                          samples.size());
- } else {
-   mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
-                         samples.size());
- }
+ AcceptWaveformWrapper(config_.sampling_rate, samples.data(),
+                       samples.size());
return;
}
- if (fbank_) {
-   fbank_->AcceptWaveform(sampling_rate, waveform, n);
- } else {
-   mfcc_->AcceptWaveform(sampling_rate, waveform, n);
- }
+ AcceptWaveformWrapper(sampling_rate, waveform, n);
}
void InputFinished() const {
std::lock_guard<std::mutex> lock(mutex_);
- fbank_->InputFinished();
+ if (fbank_) {
+   fbank_->InputFinished();
+   return;
+ } else if (whisper_fbank_) {
+   whisper_fbank_->InputFinished();
+   return;
+ } else if (mfcc_) {
+   mfcc_->InputFinished();
+   return;
+ }
+ SHERPA_ONNX_LOGE("unreachable code");
+ SHERPA_ONNX_EXIT(-1);
}
int32_t NumFramesReady() const {
std::lock_guard<std::mutex> lock(mutex_);
- return fbank_->NumFramesReady();
+ if (fbank_) {
+   return fbank_->NumFramesReady();
+ } else if (whisper_fbank_) {
+   return whisper_fbank_->NumFramesReady();
+ } else if (mfcc_) {
+   return mfcc_->NumFramesReady();
+ }
+ SHERPA_ONNX_LOGE("unreachable code");
+ SHERPA_ONNX_EXIT(-1);
+ return -1;
}
bool IsLastFrame(int32_t frame) const {
std::lock_guard<std::mutex> lock(mutex_);
- return fbank_->IsLastFrame(frame);
+ if (fbank_) {
+   return fbank_->IsLastFrame(frame);
+ } else if (whisper_fbank_) {
+   return whisper_fbank_->IsLastFrame(frame);
+ } else if (mfcc_) {
+   return mfcc_->IsLastFrame(frame);
+ }
+ SHERPA_ONNX_LOGE("unreachable code");
+ SHERPA_ONNX_EXIT(-1);
+ return false;
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
- if (frame_index + n > fbank_->NumFramesReady()) {
-   SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
-                    fbank_->NumFramesReady());
-   exit(-1);
+ if (frame_index + n > NumFramesReady()) {
+   SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady());
+   SHERPA_ONNX_EXIT(-1);
}
int32_t discard_num = frame_index - last_frame_index_;
if (discard_num < 0) {
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
last_frame_index_, frame_index);
- exit(-1);
+ SHERPA_ONNX_EXIT(-1);
}
- fbank_->Pop(discard_num);
- int32_t feature_dim = fbank_->Dim();
+ PopWrapper(discard_num);
+ int32_t feature_dim = FeatureDim();
std::vector<float> features(feature_dim * n);
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
- const float *f = fbank_->GetFrame(i + frame_index);
+ const float *f = GetFrameWrapper(i + frame_index);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
... ... @@ -184,10 +202,65 @@ class FeatureExtractor::Impl {
}
int32_t FeatureDim() const {
- return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
+ if (fbank_ || whisper_fbank_) {
+   return opts_.mel_opts.num_bins;
+ } else if (mfcc_) {
+   return mfcc_opts_.num_ceps;
+ }
+ SHERPA_ONNX_LOGE("unreachable code");
+ SHERPA_ONNX_EXIT(-1);
+ return -1;
}
private:
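// Exactly one of fbank_, whisper_fbank_, and mfcc_ is created in the
// constructor; the wrappers below forward to whichever one is active.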
void AcceptWaveformWrapper(float sampling_rate, const float *waveform,
int32_t n) const {
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
return;
} else if (whisper_fbank_) {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
return;
} else if (mfcc_) {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
return;
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
}
const float *GetFrameWrapper(int32_t frame_index) const {
if (fbank_) {
return fbank_->GetFrame(frame_index);
} else if (whisper_fbank_) {
return whisper_fbank_->GetFrame(frame_index);
} else if (mfcc_) {
return mfcc_->GetFrame(frame_index);
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
return nullptr;
}
void PopWrapper(int32_t discard_num) const {
if (fbank_) {
fbank_->Pop(discard_num);
return;
} else if (whisper_fbank_) {
whisper_fbank_->Pop(discard_num);
return;
} else if (mfcc_) {
mfcc_->Pop(discard_num);
return;
}
SHERPA_ONNX_LOGE("unreachable code");
SHERPA_ONNX_EXIT(-1);
}
void InitFbank() {
opts_.frame_opts.dither = config_.dither;
opts_.frame_opts.snip_edges = config_.snip_edges;
... ... @@ -208,6 +281,7 @@ class FeatureExtractor::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void InitMfcc() {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
... ... @@ -232,9 +306,23 @@ class FeatureExtractor::Impl {
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
}
void InitWhisper() {
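// Whisper's front end expects 16 kHz samples in the range [-1, 1], so
// sample normalization is forced on and the sampling rate is fixed to 16000.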
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = config_.feature_dim;
knf::WhisperFeatureOptions whisper_opts;
whisper_opts.frame_opts = opts_.frame_opts;
whisper_opts.dim = config_.feature_dim;
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
FeatureExtractorConfig config_;
... ...
... ... @@ -79,6 +79,8 @@ struct FeatureExtractorConfig {
bool is_mfcc = false;
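// True to compute Whisper-style log-mel features instead of fbank/MFCC.
// It is set automatically by the recognizer from the model's metadata.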
bool is_whisper = false;
bool round_to_power_of_two = true;
std::string ToString() const;
... ...
... ... @@ -77,6 +77,8 @@ class OnlineCtcModel {
// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
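// Return true if the model was trained with Whisper-style features;
// the recognizer then switches the feature extractor accordingly.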
virtual bool UseWhisperFeature() const { return false; }
};
} // namespace sherpa_onnx
... ...
... ... @@ -15,6 +15,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
... ... @@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_.feat_config.normalize_samples = false;
}
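// Switch to Whisper-style features when the model's metadata asks for it.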
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
}
... ... @@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
config_.feat_config.normalize_samples = false;
}
if (model_->UseWhisperFeature()) {
config_.feat_config.is_whisper = true;
}
InitDecoder();
}
... ... @@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_length);
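// Whisper-style features are normalized with
// OfflineWhisperModel::NormalizeFeatures before being fed to the model.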
if (config_.feat_config.is_whisper) {
OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_length,
feat_dim);
}
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
... ... @@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
const auto num_processed_frames = s->GetNumProcessedFrames();
std::vector<float> frames =
s->GetFrames(num_processed_frames, chunk_length);
if (config_.feat_config.is_whisper) {
OfflineWhisperModel::NormalizeFeatures(frames.data(), chunk_length,
feat_dim);
}
s->GetNumProcessedFrames() += chunk_shift;
auto memory_info =
... ...
... ... @@ -19,34 +19,51 @@ class OnlineStream::Impl {
: feat_extractor_(config), context_graph_(std::move(context_graph)) {}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
}
- void InputFinished() const { feat_extractor_.InputFinished(); }
+ void InputFinished() const {
+   std::lock_guard<std::mutex> lock(mutex_);
+   feat_extractor_.InputFinished();
+ }
int32_t NumFramesReady() const {
std::lock_guard<std::mutex> lock(mutex_);
return feat_extractor_.NumFramesReady() - start_frame_index_;
}
bool IsLastFrame(int32_t frame) const {
std::lock_guard<std::mutex> lock(mutex_);
return feat_extractor_.IsLastFrame(frame);
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
std::lock_guard<std::mutex> lock(mutex_);
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
}
void Reset() {
std::lock_guard<std::mutex> lock(mutex_);
// we don't reset the feature extractor
start_frame_index_ += num_processed_frames_;
num_processed_frames_ = 0;
}
- int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
+ int32_t &GetNumProcessedFrames() {
+   std::lock_guard<std::mutex> lock(mutex_);
+   return num_processed_frames_;
+ }
- int32_t GetNumFramesSinceStart() const { return start_frame_index_; }
+ int32_t GetNumFramesSinceStart() const {
+   std::lock_guard<std::mutex> lock(mutex_);
+   return start_frame_index_;
+ }
- int32_t &GetCurrentSegment() { return segment_; }
+ int32_t &GetCurrentSegment() {
+   std::lock_guard<std::mutex> lock(mutex_);
+   return segment_;
+ }
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
... ... @@ -125,6 +142,7 @@ class OnlineStream::Impl {
private:
FeatureExtractor feat_extractor_;
mutable std::mutex mutex_;
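// Protects feat_extractor_ and the frame/segment counters when the
// stream is used from multiple threads.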
/// For contextual-biasing
ContextGraphPtr context_graph_;
int32_t num_processed_frames_ = 0; // before subsampling
... ...
... ... @@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl {
int32_t ChunkShift() const { return decode_chunk_len_; }
bool UseWhisperFeature() const { return use_whisper_feature_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
... ... @@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl {
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
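// Models exported with Whisper-style features carry a "feature" metadata
// entry set to "whisper"; the entry is optional, hence the empty default.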
std::string feature_type;
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", "");
if (feature_type == "whisper") {
use_whisper_feature_ = true;
}
{
auto shape =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
... ... @@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl {
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t vocab_size_ = 0;
// for models from
// https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head
bool use_whisper_feature_ = false;
};
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
... ... @@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const {
return impl_->ChunkShift();
}
bool OnlineZipformer2CtcModel::UseWhisperFeature() const {
return impl_->UseWhisperFeature();
}
OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
return impl_->Allocator();
}
... ...
... ... @@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel {
// before we process the next chunk.
int32_t ChunkShift() const override;
bool UseWhisperFeature() const override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -130,7 +130,7 @@ for a list of pre-trained models to download.
}
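// The stream itself is passed as the PortAudio user_data pointer so that
// RecordCallback can use it.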
if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback,
- nullptr /* user_data */)) {
+ s.get())) {
fprintf(stderr, "portaudio error: %d\n", device_index);
exit(EXIT_FAILURE);
}
... ...