Committed by
GitHub
Limit number of tokens per second for whisper. (#1958)
Otherwise, it spends lots of time in the loop if the EOT token is not predicted.
正在显示
4 个修改的文件
包含
14 行增加
和
6 行删除
| @@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 131 | auto cross_kv = model_->ForwardEncoder(std::move(mel)); | 131 | auto cross_kv = model_->ForwardEncoder(std::move(mel)); |
| 132 | 132 | ||
| 133 | auto results = decoder_->Decode(std::move(cross_kv.first), | 133 | auto results = decoder_->Decode(std::move(cross_kv.first), |
| 134 | - std::move(cross_kv.second)); | 134 | + std::move(cross_kv.second), num_frames); |
| 135 | 135 | ||
| 136 | auto r = Convert(results[0], symbol_table_); | 136 | auto r = Convert(results[0], symbol_table_); |
| 137 | s->SetResult(r); | 137 | s->SetResult(r); |
| @@ -33,7 +33,8 @@ class OfflineWhisperDecoder { | @@ -33,7 +33,8 @@ class OfflineWhisperDecoder { | ||
| 33 | * @return Return a vector of size `N` containing the decoded results. | 33 | * @return Return a vector of size `N` containing the decoded results. |
| 34 | */ | 34 | */ |
| 35 | virtual std::vector<OfflineWhisperDecoderResult> Decode( | 35 | virtual std::vector<OfflineWhisperDecoderResult> Decode( |
| 36 | - Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; | 36 | + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v, |
| 37 | + int32_t num_feature_frames) = 0; | ||
| 37 | 38 | ||
| 38 | virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; | 39 | virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; |
| 39 | }; | 40 | }; |
| @@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig( | @@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig( | ||
| 19 | 19 | ||
| 20 | std::vector<OfflineWhisperDecoderResult> | 20 | std::vector<OfflineWhisperDecoderResult> |
| 21 | OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | 21 | OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, |
| 22 | - Ort::Value cross_v) { | 22 | + Ort::Value cross_v, |
| 23 | + int32_t num_feature_frames) { | ||
| 23 | auto memory_info = | 24 | auto memory_info = |
| 24 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 25 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 25 | 26 | ||
| @@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | @@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 99 | int32_t n_text_ctx = model_->TextCtx(); | 100 | int32_t n_text_ctx = model_->TextCtx(); |
| 100 | 101 | ||
| 101 | std::vector<int32_t> predicted_tokens; | 102 | std::vector<int32_t> predicted_tokens; |
| 102 | - for (int32_t i = 0; i < n_text_ctx / 2; ++i) { | 103 | + |
| 104 | + // assume at most 6 tokens per second | ||
| 105 | + int32_t num_possible_tokens = num_feature_frames / 100 * 6; | ||
| 106 | + num_possible_tokens = std::min<int32_t>(num_possible_tokens, n_text_ctx / 2); | ||
| 107 | + | ||
| 108 | + for (int32_t i = 0; i < num_possible_tokens; ++i) { | ||
| 103 | if (max_token_id == model_->EOT()) { | 109 | if (max_token_id == model_->EOT()) { |
| 104 | break; | 110 | break; |
| 105 | } | 111 | } |
| @@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | @@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | ||
| 18 | OfflineWhisperModel *model) | 18 | OfflineWhisperModel *model) |
| 19 | : config_(config), model_(model) {} | 19 | : config_(config), model_(model) {} |
| 20 | 20 | ||
| 21 | - std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, | ||
| 22 | - Ort::Value cross_v) override; | 21 | + std::vector<OfflineWhisperDecoderResult> Decode( |
| 22 | + Ort::Value cross_k, Ort::Value cross_v, | ||
| 23 | + int32_t num_feature_frames) override; | ||
| 23 | 24 | ||
| 24 | void SetConfig(const OfflineWhisperModelConfig &config) override; | 25 | void SetConfig(const OfflineWhisperModelConfig &config) override; |
| 25 | 26 |
-
请 注册 或 登录 后发表评论