Committed by GitHub
Support zipformer CTC ASR with whisper features. (#2319)
正在显示 8 个修改的文件，包含 184 行增加和 37 行删除
| @@ -60,6 +60,8 @@ class FeatureExtractor::Impl { | @@ -60,6 +60,8 @@ class FeatureExtractor::Impl { | ||
| 60 | explicit Impl(const FeatureExtractorConfig &config) : config_(config) { | 60 | explicit Impl(const FeatureExtractorConfig &config) : config_(config) { |
| 61 | if (config_.is_mfcc) { | 61 | if (config_.is_mfcc) { |
| 62 | InitMfcc(); | 62 | InitMfcc(); |
| 63 | + } else if (config_.is_whisper) { | ||
| 64 | + InitWhisper(); | ||
| 63 | } else { | 65 | } else { |
| 64 | InitFbank(); | 66 | InitFbank(); |
| 65 | } | 67 | } |
| @@ -92,13 +94,9 @@ class FeatureExtractor::Impl { | @@ -92,13 +94,9 @@ class FeatureExtractor::Impl { | ||
| 92 | 94 | ||
| 93 | std::vector<float> samples; | 95 | std::vector<float> samples; |
| 94 | resampler_->Resample(waveform, n, false, &samples); | 96 | resampler_->Resample(waveform, n, false, &samples); |
| 95 | - if (fbank_) { | ||
| 96 | - fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), | ||
| 97 | - samples.size()); | ||
| 98 | - } else { | ||
| 99 | - mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), | ||
| 100 | - samples.size()); | ||
| 101 | - } | 97 | + |
| 98 | + AcceptWaveformWrapper(config_.sampling_rate, samples.data(), | ||
| 99 | + samples.size()); | ||
| 102 | return; | 100 | return; |
| 103 | } | 101 | } |
| 104 | 102 | ||
| @@ -119,61 +117,81 @@ class FeatureExtractor::Impl { | @@ -119,61 +117,81 @@ class FeatureExtractor::Impl { | ||
| 119 | 117 | ||
| 120 | std::vector<float> samples; | 118 | std::vector<float> samples; |
| 121 | resampler_->Resample(waveform, n, false, &samples); | 119 | resampler_->Resample(waveform, n, false, &samples); |
| 122 | - if (fbank_) { | ||
| 123 | - fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), | ||
| 124 | - samples.size()); | ||
| 125 | - } else { | ||
| 126 | - mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), | ||
| 127 | - samples.size()); | ||
| 128 | - } | 120 | + |
| 121 | + AcceptWaveformWrapper(config_.sampling_rate, samples.data(), | ||
| 122 | + samples.size()); | ||
| 123 | + | ||
| 129 | return; | 124 | return; |
| 130 | } | 125 | } |
| 131 | 126 | ||
| 132 | - if (fbank_) { | ||
| 133 | - fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 134 | - } else { | ||
| 135 | - mfcc_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 136 | - } | 127 | + AcceptWaveformWrapper(sampling_rate, waveform, n); |
| 137 | } | 128 | } |
| 138 | 129 | ||
| 139 | void InputFinished() const { | 130 | void InputFinished() const { |
| 140 | std::lock_guard<std::mutex> lock(mutex_); | 131 | std::lock_guard<std::mutex> lock(mutex_); |
| 141 | - fbank_->InputFinished(); | 132 | + if (fbank_) { |
| 133 | + fbank_->InputFinished(); | ||
| 134 | + } else if (whisper_fbank_) { | ||
| 135 | + whisper_fbank_->InputFinished(); | ||
| 136 | + } else if (mfcc_) { | ||
| 137 | + mfcc_->InputFinished(); | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 141 | + SHERPA_ONNX_EXIT(-1); | ||
| 142 | } | 142 | } |
| 143 | 143 | ||
| 144 | int32_t NumFramesReady() const { | 144 | int32_t NumFramesReady() const { |
| 145 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 146 | - return fbank_->NumFramesReady(); | 145 | + if (fbank_) { |
| 146 | + return fbank_->NumFramesReady(); | ||
| 147 | + } else if (whisper_fbank_) { | ||
| 148 | + return whisper_fbank_->NumFramesReady(); | ||
| 149 | + } else if (mfcc_) { | ||
| 150 | + return mfcc_->NumFramesReady(); | ||
| 151 | + } | ||
| 152 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 153 | + SHERPA_ONNX_EXIT(-1); | ||
| 154 | + return -1; | ||
| 147 | } | 155 | } |
| 148 | 156 | ||
| 149 | bool IsLastFrame(int32_t frame) const { | 157 | bool IsLastFrame(int32_t frame) const { |
| 150 | std::lock_guard<std::mutex> lock(mutex_); | 158 | std::lock_guard<std::mutex> lock(mutex_); |
| 151 | - return fbank_->IsLastFrame(frame); | 159 | + if (fbank_) { |
| 160 | + return fbank_->IsLastFrame(frame); | ||
| 161 | + } else if (whisper_fbank_) { | ||
| 162 | + return whisper_fbank_->IsLastFrame(frame); | ||
| 163 | + } else if (mfcc_) { | ||
| 164 | + return mfcc_->IsLastFrame(frame); | ||
| 165 | + } | ||
| 166 | + | ||
| 167 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 168 | + SHERPA_ONNX_EXIT(-1); | ||
| 169 | + return false; | ||
| 152 | } | 170 | } |
| 153 | 171 | ||
| 154 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) { | 172 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) { |
| 155 | std::lock_guard<std::mutex> lock(mutex_); | 173 | std::lock_guard<std::mutex> lock(mutex_); |
| 156 | - if (frame_index + n > fbank_->NumFramesReady()) { | ||
| 157 | - SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, | ||
| 158 | - fbank_->NumFramesReady()); | ||
| 159 | - exit(-1); | 174 | + if (frame_index + n > NumFramesReady()) { |
| 175 | + SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, NumFramesReady()); | ||
| 176 | + SHERPA_ONNX_EXIT(-1); | ||
| 160 | } | 177 | } |
| 161 | 178 | ||
| 162 | int32_t discard_num = frame_index - last_frame_index_; | 179 | int32_t discard_num = frame_index - last_frame_index_; |
| 163 | if (discard_num < 0) { | 180 | if (discard_num < 0) { |
| 164 | SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d", | 181 | SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d", |
| 165 | last_frame_index_, frame_index); | 182 | last_frame_index_, frame_index); |
| 166 | - exit(-1); | 183 | + SHERPA_ONNX_EXIT(-1); |
| 167 | } | 184 | } |
| 168 | - fbank_->Pop(discard_num); | ||
| 169 | 185 | ||
| 170 | - int32_t feature_dim = fbank_->Dim(); | 186 | + PopWrapper(discard_num); |
| 187 | + | ||
| 188 | + int32_t feature_dim = FeatureDim(); | ||
| 171 | std::vector<float> features(feature_dim * n); | 189 | std::vector<float> features(feature_dim * n); |
| 172 | 190 | ||
| 173 | float *p = features.data(); | 191 | float *p = features.data(); |
| 174 | 192 | ||
| 175 | for (int32_t i = 0; i != n; ++i) { | 193 | for (int32_t i = 0; i != n; ++i) { |
| 176 | - const float *f = fbank_->GetFrame(i + frame_index); | 194 | + const float *f = GetFrameWrapper(i + frame_index); |
| 177 | std::copy(f, f + feature_dim, p); | 195 | std::copy(f, f + feature_dim, p); |
| 178 | p += feature_dim; | 196 | p += feature_dim; |
| 179 | } | 197 | } |
| @@ -184,10 +202,65 @@ class FeatureExtractor::Impl { | @@ -184,10 +202,65 @@ class FeatureExtractor::Impl { | ||
| 184 | } | 202 | } |
| 185 | 203 | ||
| 186 | int32_t FeatureDim() const { | 204 | int32_t FeatureDim() const { |
| 187 | - return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; | 205 | + if (fbank_ || whisper_fbank_) { |
| 206 | + return opts_.mel_opts.num_bins; | ||
| 207 | + } else if (mfcc_) { | ||
| 208 | + return mfcc_opts_.num_ceps; | ||
| 209 | + } | ||
| 210 | + | ||
| 211 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 212 | + SHERPA_ONNX_EXIT(-1); | ||
| 213 | + return -1; | ||
| 188 | } | 214 | } |
| 189 | 215 | ||
| 190 | private: | 216 | private: |
| 217 | + void AcceptWaveformWrapper(float sampling_rate, const float *waveform, | ||
| 218 | + int32_t n) const { | ||
| 219 | + if (fbank_) { | ||
| 220 | + fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 221 | + return; | ||
| 222 | + } else if (whisper_fbank_) { | ||
| 223 | + whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 224 | + return; | ||
| 225 | + } else if (mfcc_) { | ||
| 226 | + mfcc_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 227 | + return; | ||
| 228 | + } | ||
| 229 | + | ||
| 230 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 231 | + SHERPA_ONNX_EXIT(-1); | ||
| 232 | + } | ||
| 233 | + | ||
| 234 | + const float *GetFrameWrapper(int32_t frame_index) const { | ||
| 235 | + if (fbank_) { | ||
| 236 | + return fbank_->GetFrame(frame_index); | ||
| 237 | + } else if (whisper_fbank_) { | ||
| 238 | + return whisper_fbank_->GetFrame(frame_index); | ||
| 239 | + } else if (mfcc_) { | ||
| 240 | + return mfcc_->GetFrame(frame_index); | ||
| 241 | + } | ||
| 242 | + | ||
| 243 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 244 | + SHERPA_ONNX_EXIT(-1); | ||
| 245 | + return nullptr; | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + void PopWrapper(int32_t discard_num) const { | ||
| 249 | + if (fbank_) { | ||
| 250 | + fbank_->Pop(discard_num); | ||
| 251 | + return; | ||
| 252 | + } else if (whisper_fbank_) { | ||
| 253 | + whisper_fbank_->Pop(discard_num); | ||
| 254 | + return; | ||
| 255 | + } else if (mfcc_) { | ||
| 256 | + mfcc_->Pop(discard_num); | ||
| 257 | + return; | ||
| 258 | + } | ||
| 259 | + | ||
| 260 | + SHERPA_ONNX_LOGE("unreachable code"); | ||
| 261 | + SHERPA_ONNX_EXIT(-1); | ||
| 262 | + } | ||
| 263 | + | ||
| 191 | void InitFbank() { | 264 | void InitFbank() { |
| 192 | opts_.frame_opts.dither = config_.dither; | 265 | opts_.frame_opts.dither = config_.dither; |
| 193 | opts_.frame_opts.snip_edges = config_.snip_edges; | 266 | opts_.frame_opts.snip_edges = config_.snip_edges; |
| @@ -208,6 +281,7 @@ class FeatureExtractor::Impl { | @@ -208,6 +281,7 @@ class FeatureExtractor::Impl { | ||
| 208 | 281 | ||
| 209 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 282 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 210 | } | 283 | } |
| 284 | + | ||
| 211 | void InitMfcc() { | 285 | void InitMfcc() { |
| 212 | mfcc_opts_.frame_opts.dither = config_.dither; | 286 | mfcc_opts_.frame_opts.dither = config_.dither; |
| 213 | mfcc_opts_.frame_opts.snip_edges = config_.snip_edges; | 287 | mfcc_opts_.frame_opts.snip_edges = config_.snip_edges; |
| @@ -232,9 +306,23 @@ class FeatureExtractor::Impl { | @@ -232,9 +306,23 @@ class FeatureExtractor::Impl { | ||
| 232 | mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_); | 306 | mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_); |
| 233 | } | 307 | } |
| 234 | 308 | ||
| 309 | + void InitWhisper() { | ||
| 310 | + config_.normalize_samples = true; | ||
| 311 | + opts_.frame_opts.samp_freq = 16000; | ||
| 312 | + opts_.mel_opts.num_bins = config_.feature_dim; | ||
| 313 | + | ||
| 314 | + knf::WhisperFeatureOptions whisper_opts; | ||
| 315 | + whisper_opts.frame_opts = opts_.frame_opts; | ||
| 316 | + whisper_opts.dim = config_.feature_dim; | ||
| 317 | + | ||
| 318 | + whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts); | ||
| 319 | + config_.sampling_rate = opts_.frame_opts.samp_freq; | ||
| 320 | + } | ||
| 321 | + | ||
| 235 | private: | 322 | private: |
| 236 | std::unique_ptr<knf::OnlineFbank> fbank_; | 323 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 237 | std::unique_ptr<knf::OnlineMfcc> mfcc_; | 324 | std::unique_ptr<knf::OnlineMfcc> mfcc_; |
| 325 | + std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_; | ||
| 238 | knf::FbankOptions opts_; | 326 | knf::FbankOptions opts_; |
| 239 | knf::MfccOptions mfcc_opts_; | 327 | knf::MfccOptions mfcc_opts_; |
| 240 | FeatureExtractorConfig config_; | 328 | FeatureExtractorConfig config_; |
| @@ -79,6 +79,8 @@ struct FeatureExtractorConfig { | @@ -79,6 +79,8 @@ struct FeatureExtractorConfig { | ||
| 79 | 79 | ||
| 80 | bool is_mfcc = false; | 80 | bool is_mfcc = false; |
| 81 | 81 | ||
| 82 | + bool is_whisper = false; | ||
| 83 | + | ||
| 82 | bool round_to_power_of_two = true; | 84 | bool round_to_power_of_two = true; |
| 83 | 85 | ||
| 84 | std::string ToString() const; | 86 | std::string ToString() const; |
| @@ -77,6 +77,8 @@ class OnlineCtcModel { | @@ -77,6 +77,8 @@ class OnlineCtcModel { | ||
| 77 | 77 | ||
| 78 | // Return true if the model supports batch size > 1 | 78 | // Return true if the model supports batch size > 1 |
| 79 | virtual bool SupportBatchProcessing() const { return true; } | 79 | virtual bool SupportBatchProcessing() const { return true; } |
| 80 | + | ||
| 81 | + virtual bool UseWhisperFeature() const { return false; } | ||
| 80 | }; | 82 | }; |
| 81 | 83 | ||
| 82 | } // namespace sherpa_onnx | 84 | } // namespace sherpa_onnx |
| @@ -15,6 +15,7 @@ | @@ -15,6 +15,7 @@ | ||
| 15 | 15 | ||
| 16 | #include "sherpa-onnx/csrc/file-utils.h" | 16 | #include "sherpa-onnx/csrc/file-utils.h" |
| 17 | #include "sherpa-onnx/csrc/macros.h" | 17 | #include "sherpa-onnx/csrc/macros.h" |
| 18 | +#include "sherpa-onnx/csrc/offline-whisper-model.h" | ||
| 18 | #include "sherpa-onnx/csrc/online-ctc-decoder.h" | 19 | #include "sherpa-onnx/csrc/online-ctc-decoder.h" |
| 19 | #include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" | 20 | #include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" |
| 20 | #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" | 21 | #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" |
| @@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -91,6 +92,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 91 | config_.feat_config.normalize_samples = false; | 92 | config_.feat_config.normalize_samples = false; |
| 92 | } | 93 | } |
| 93 | 94 | ||
| 95 | + if (model_->UseWhisperFeature()) { | ||
| 96 | + config_.feat_config.is_whisper = true; | ||
| 97 | + } | ||
| 98 | + | ||
| 94 | InitDecoder(); | 99 | InitDecoder(); |
| 95 | } | 100 | } |
| 96 | 101 | ||
| @@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -108,6 +113,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 108 | config_.feat_config.normalize_samples = false; | 113 | config_.feat_config.normalize_samples = false; |
| 109 | } | 114 | } |
| 110 | 115 | ||
| 116 | + if (model_->UseWhisperFeature()) { | ||
| 117 | + config_.feat_config.is_whisper = true; | ||
| 118 | + } | ||
| 119 | + | ||
| 111 | InitDecoder(); | 120 | InitDecoder(); |
| 112 | } | 121 | } |
| 113 | 122 | ||
| @@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -147,6 +156,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 147 | const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); | 156 | const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); |
| 148 | std::vector<float> features = | 157 | std::vector<float> features = |
| 149 | ss[i]->GetFrames(num_processed_frames, chunk_length); | 158 | ss[i]->GetFrames(num_processed_frames, chunk_length); |
| 159 | + if (config_.feat_config.is_whisper) { | ||
| 160 | + OfflineWhisperModel::NormalizeFeatures(features.data(), chunk_length, | ||
| 161 | + feat_dim); | ||
| 162 | + } | ||
| 150 | 163 | ||
| 151 | // Question: should num_processed_frames include chunk_shift? | 164 | // Question: should num_processed_frames include chunk_shift? |
| 152 | ss[i]->GetNumProcessedFrames() += chunk_shift; | 165 | ss[i]->GetNumProcessedFrames() += chunk_shift; |
| @@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -287,6 +300,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 287 | const auto num_processed_frames = s->GetNumProcessedFrames(); | 300 | const auto num_processed_frames = s->GetNumProcessedFrames(); |
| 288 | std::vector<float> frames = | 301 | std::vector<float> frames = |
| 289 | s->GetFrames(num_processed_frames, chunk_length); | 302 | s->GetFrames(num_processed_frames, chunk_length); |
| 303 | + | ||
| 304 | + if (config_.feat_config.is_whisper) { | ||
| 305 | + OfflineWhisperModel::NormalizeFeatures(frames.data(), chunk_length, | ||
| 306 | + feat_dim); | ||
| 307 | + } | ||
| 308 | + | ||
| 290 | s->GetNumProcessedFrames() += chunk_shift; | 309 | s->GetNumProcessedFrames() += chunk_shift; |
| 291 | 310 | ||
| 292 | auto memory_info = | 311 | auto memory_info = |
| @@ -19,34 +19,51 @@ class OnlineStream::Impl { | @@ -19,34 +19,51 @@ class OnlineStream::Impl { | ||
| 19 | : feat_extractor_(config), context_graph_(std::move(context_graph)) {} | 19 | : feat_extractor_(config), context_graph_(std::move(context_graph)) {} |
| 20 | 20 | ||
| 21 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 21 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 22 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 22 | feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); | 23 | feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); |
| 23 | } | 24 | } |
| 24 | 25 | ||
| 25 | - void InputFinished() const { feat_extractor_.InputFinished(); } | 26 | + void InputFinished() const { |
| 27 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 28 | + feat_extractor_.InputFinished(); | ||
| 29 | + } | ||
| 26 | 30 | ||
| 27 | int32_t NumFramesReady() const { | 31 | int32_t NumFramesReady() const { |
| 32 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 28 | return feat_extractor_.NumFramesReady() - start_frame_index_; | 33 | return feat_extractor_.NumFramesReady() - start_frame_index_; |
| 29 | } | 34 | } |
| 30 | 35 | ||
| 31 | bool IsLastFrame(int32_t frame) const { | 36 | bool IsLastFrame(int32_t frame) const { |
| 37 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 32 | return feat_extractor_.IsLastFrame(frame); | 38 | return feat_extractor_.IsLastFrame(frame); |
| 33 | } | 39 | } |
| 34 | 40 | ||
| 35 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { | 41 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { |
| 42 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 36 | return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); | 43 | return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); |
| 37 | } | 44 | } |
| 38 | 45 | ||
| 39 | void Reset() { | 46 | void Reset() { |
| 47 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 40 | // we don't reset the feature extractor | 48 | // we don't reset the feature extractor |
| 41 | start_frame_index_ += num_processed_frames_; | 49 | start_frame_index_ += num_processed_frames_; |
| 42 | num_processed_frames_ = 0; | 50 | num_processed_frames_ = 0; |
| 43 | } | 51 | } |
| 44 | 52 | ||
| 45 | - int32_t &GetNumProcessedFrames() { return num_processed_frames_; } | 53 | + int32_t &GetNumProcessedFrames() { |
| 54 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 55 | + return num_processed_frames_; | ||
| 56 | + } | ||
| 46 | 57 | ||
| 47 | - int32_t GetNumFramesSinceStart() const { return start_frame_index_; } | 58 | + int32_t GetNumFramesSinceStart() const { |
| 59 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 60 | + return start_frame_index_; | ||
| 61 | + } | ||
| 48 | 62 | ||
| 49 | - int32_t &GetCurrentSegment() { return segment_; } | 63 | + int32_t &GetCurrentSegment() { |
| 64 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 65 | + return segment_; | ||
| 66 | + } | ||
| 50 | 67 | ||
| 51 | void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } | 68 | void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } |
| 52 | 69 | ||
| @@ -125,6 +142,7 @@ class OnlineStream::Impl { | @@ -125,6 +142,7 @@ class OnlineStream::Impl { | ||
| 125 | 142 | ||
| 126 | private: | 143 | private: |
| 127 | FeatureExtractor feat_extractor_; | 144 | FeatureExtractor feat_extractor_; |
| 145 | + mutable std::mutex mutex_; | ||
| 128 | /// For contextual-biasing | 146 | /// For contextual-biasing |
| 129 | ContextGraphPtr context_graph_; | 147 | ContextGraphPtr context_graph_; |
| 130 | int32_t num_processed_frames_ = 0; // before subsampling | 148 | int32_t num_processed_frames_ = 0; // before subsampling |
| @@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl { | @@ -74,6 +74,8 @@ class OnlineZipformer2CtcModel::Impl { | ||
| 74 | 74 | ||
| 75 | int32_t ChunkShift() const { return decode_chunk_len_; } | 75 | int32_t ChunkShift() const { return decode_chunk_len_; } |
| 76 | 76 | ||
| 77 | + bool UseWhisperFeature() const { return use_whisper_feature_; } | ||
| 78 | + | ||
| 77 | OrtAllocator *Allocator() { return allocator_; } | 79 | OrtAllocator *Allocator() { return allocator_; } |
| 78 | 80 | ||
| 79 | // Return a vector containing 3 tensors | 81 | // Return a vector containing 3 tensors |
| @@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl { | @@ -278,6 +280,12 @@ class OnlineZipformer2CtcModel::Impl { | ||
| 278 | SHERPA_ONNX_READ_META_DATA(T_, "T"); | 280 | SHERPA_ONNX_READ_META_DATA(T_, "T"); |
| 279 | SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); | 281 | SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); |
| 280 | 282 | ||
| 283 | + std::string feature_type; | ||
| 284 | + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(feature_type, "feature", ""); | ||
| 285 | + if (feature_type == "whisper") { | ||
| 286 | + use_whisper_feature_ = true; | ||
| 287 | + } | ||
| 288 | + | ||
| 281 | { | 289 | { |
| 282 | auto shape = | 290 | auto shape = |
| 283 | sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); | 291 | sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); |
| @@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl { | @@ -417,6 +425,10 @@ class OnlineZipformer2CtcModel::Impl { | ||
| 417 | int32_t T_ = 0; | 425 | int32_t T_ = 0; |
| 418 | int32_t decode_chunk_len_ = 0; | 426 | int32_t decode_chunk_len_ = 0; |
| 419 | int32_t vocab_size_ = 0; | 427 | int32_t vocab_size_ = 0; |
| 428 | + | ||
| 429 | + // for models from | ||
| 430 | + // https://github.com/k2-fsa/icefall/blob/master/egs/multi_zh-hans/ASR/RESULTS.md#streaming-with-ctc-head | ||
| 431 | + bool use_whisper_feature_ = false; | ||
| 420 | }; | 432 | }; |
| 421 | 433 | ||
| 422 | OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( | 434 | OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( |
| @@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const { | @@ -447,6 +459,10 @@ int32_t OnlineZipformer2CtcModel::ChunkShift() const { | ||
| 447 | return impl_->ChunkShift(); | 459 | return impl_->ChunkShift(); |
| 448 | } | 460 | } |
| 449 | 461 | ||
| 462 | +bool OnlineZipformer2CtcModel::UseWhisperFeature() const { | ||
| 463 | + return impl_->UseWhisperFeature(); | ||
| 464 | +} | ||
| 465 | + | ||
| 450 | OrtAllocator *OnlineZipformer2CtcModel::Allocator() const { | 466 | OrtAllocator *OnlineZipformer2CtcModel::Allocator() const { |
| 451 | return impl_->Allocator(); | 467 | return impl_->Allocator(); |
| 452 | } | 468 | } |
| @@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel { | @@ -64,6 +64,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel { | ||
| 64 | // before we process the next chunk. | 64 | // before we process the next chunk. |
| 65 | int32_t ChunkShift() const override; | 65 | int32_t ChunkShift() const override; |
| 66 | 66 | ||
| 67 | + bool UseWhisperFeature() const override; | ||
| 68 | + | ||
| 67 | private: | 69 | private: |
| 68 | class Impl; | 70 | class Impl; |
| 69 | std::unique_ptr<Impl> impl_; | 71 | std::unique_ptr<Impl> impl_; |
| @@ -130,7 +130,7 @@ for a list of pre-trained models to download. | @@ -130,7 +130,7 @@ for a list of pre-trained models to download. | ||
| 130 | } | 130 | } |
| 131 | 131 | ||
| 132 | if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback, | 132 | if (!mic.OpenDevice(device_index, mic_sample_rate, 1, RecordCallback, |
| 133 | - nullptr /* user_data */)) { | 133 | + s.get())) { |
| 134 | fprintf(stderr, "portaudio error: %d\n", device_index); | 134 | fprintf(stderr, "portaudio error: %d\n", device_index); |
| 135 | exit(EXIT_FAILURE); | 135 | exit(EXIT_FAILURE); |
| 136 | } | 136 | } |
-
请 注册 或 登录 后发表评论