Fangjun Kuang
Committed by GitHub

Limit number of tokens per second for whisper. (#1958)

Otherwise, greedy-search decoding spends a lot of time in the loop (up to
n_text_ctx / 2 iterations) if the EOT token is never predicted.
@@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
131 auto cross_kv = model_->ForwardEncoder(std::move(mel)); 131 auto cross_kv = model_->ForwardEncoder(std::move(mel));
132 132
133 auto results = decoder_->Decode(std::move(cross_kv.first), 133 auto results = decoder_->Decode(std::move(cross_kv.first),
134 - std::move(cross_kv.second)); 134 + std::move(cross_kv.second), num_frames);
135 135
136 auto r = Convert(results[0], symbol_table_); 136 auto r = Convert(results[0], symbol_table_);
137 s->SetResult(r); 137 s->SetResult(r);
@@ -33,7 +33,8 @@ class OfflineWhisperDecoder { @@ -33,7 +33,8 @@ class OfflineWhisperDecoder {
33 * @return Return a vector of size `N` containing the decoded results. 33 * @return Return a vector of size `N` containing the decoded results.
34 */ 34 */
35 virtual std::vector<OfflineWhisperDecoderResult> Decode( 35 virtual std::vector<OfflineWhisperDecoderResult> Decode(
36 - Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; 36 + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
  37 + int32_t num_feature_frames) = 0;
37 38
38 virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; 39 virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
39 }; 40 };
@@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig( @@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig(
19 19
20 std::vector<OfflineWhisperDecoderResult> 20 std::vector<OfflineWhisperDecoderResult>
21 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, 21 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
22 - Ort::Value cross_v) { 22 + Ort::Value cross_v,
  23 + int32_t num_feature_frames) {
23 auto memory_info = 24 auto memory_info =
24 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 25 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
25 26
@@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, @@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
99 int32_t n_text_ctx = model_->TextCtx(); 100 int32_t n_text_ctx = model_->TextCtx();
100 101
101 std::vector<int32_t> predicted_tokens; 102 std::vector<int32_t> predicted_tokens;
102 - for (int32_t i = 0; i < n_text_ctx / 2; ++i) { 103 +
  104 + // assume at most 6 tokens per second
  105 + int32_t num_possible_tokens = num_feature_frames / 100 * 6;
  106 + num_possible_tokens = std::min<int32_t>(num_possible_tokens, n_text_ctx / 2);
  107 +
  108 + for (int32_t i = 0; i < num_possible_tokens; ++i) {
103 if (max_token_id == model_->EOT()) { 109 if (max_token_id == model_->EOT()) {
104 break; 110 break;
105 } 111 }
@@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { @@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
18 OfflineWhisperModel *model) 18 OfflineWhisperModel *model)
19 : config_(config), model_(model) {} 19 : config_(config), model_(model) {}
20 20
21 - std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,  
22 - Ort::Value cross_v) override; 21 + std::vector<OfflineWhisperDecoderResult> Decode(
  22 + Ort::Value cross_k, Ort::Value cross_v,
  23 + int32_t num_feature_frames) override;
23 24
24 void SetConfig(const OfflineWhisperModelConfig &config) override; 25 void SetConfig(const OfflineWhisperModelConfig &config) override;
25 26